Fix usage of head masks by PT encoder-decoder models' generate() function #11621
Conversation
* Add head_mask, decoder_head_mask and cross_attn_head_mask into prepare_inputs_for_generation for generate() function for multiple encoder-decoder models.
Hey @stancld, thanks a lot for this contribution! Could we add one test to verify that generation works with head_mask for all encoder-decoder models? I think it could be added to …
Hey @patrickvonplaten, I've added one test. At this moment, there are two little issues I'm gonna handle later today so that all encoder-decoder models will pass this new test.
BTW, if ProphetNet doesn't work with head_masking + generate, I'm totally fine with leaving it out for ProphetNet - we could then just overwrite the test to not run for ProphetNet. The model is very unique and also doesn't work fully at the moment in general.
Hi @patrickvonplaten, sorry for being silent for a while as I've been a bit too busy. As you suggest, I've skipped the test for ProphetNet.
Fix usage of head masks by PT encoder-decoder models' generate() function (huggingface#11621)
* Add missing head masking for generate() function
* Add head_mask, decoder_head_mask and cross_attn_head_mask into prepare_inputs_for_generation for generate() function for multiple encoder-decoder models.
* Add test_genereate_with_head_masking
* [WIP] Update the new test and handle special cases
* make style
* Omit ProphetNet test so far
* make fix-copies
This PR adds the missing arguments head_mask, decoder_head_mask and cross_attn_head_mask to the prepare_inputs_for_generation function of PyTorch encoder-decoder models, so that these args are used during generation when the generate() function is called.
EDIT: Need to fix the new test for ProphetNet
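The shape of the fix can be sketched as follows. This is an illustrative, simplified version of the pattern, not the actual transformers source: before the PR, prepare_inputs_for_generation simply did not include the three head-mask kwargs in the dict of model inputs it returns, so generate() silently dropped them; the fix is to accept them and pass them through.

```python
# Hedged sketch of the fix (argument names match the PR; the surrounding
# logic is simplified compared to the real transformers models):

def prepare_inputs_for_generation(
    decoder_input_ids,
    past=None,
    attention_mask=None,
    head_mask=None,              # masks encoder self-attention heads
    decoder_head_mask=None,      # masks decoder self-attention heads
    cross_attn_head_mask=None,   # masks cross-attention heads
    use_cache=None,
    encoder_outputs=None,
    **kwargs,
):
    # Before the PR the three *head_mask arguments were not forwarded here;
    # including them in the returned dict is the essence of the fix, so that
    # generate() passes them to every decoding step's forward call.
    return {
        "input_ids": None,  # encoder_outputs is set, so input_ids is unused
        "encoder_outputs": encoder_outputs,
        "past_key_values": past,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
        "use_cache": use_cache,
    }
```

With this in place, any head masks passed to generate() reach the model's forward pass at each decoding step instead of being discarded.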
Example
Behaviour before the PR:
- >>> 'The Eiffel Tower in Paris has been officially opened to the public.'
Behaviour after the PR:
+ >>> 'The Eiffel Tower in Paris has been officially opened to the public for the first time since it was completed in 1903.'
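The before/after outputs above come from generating with a head mask applied. In transformers, a head mask is a float tensor of shape (num_layers, num_heads) with 1.0 for heads to keep and 0.0 for heads to drop. A dependency-free sketch of building such a mask (make_head_mask is a hypothetical helper, modeled with plain lists instead of a torch tensor):

```python
# Hypothetical helper: build a (num_layers x num_heads) head mask where
# 1.0 keeps an attention head and 0.0 disables it.

def make_head_mask(num_layers, num_heads, disabled):
    """`disabled` is an iterable of (layer, head) index pairs to zero out."""
    mask = [[1.0] * num_heads for _ in range(num_layers)]
    for layer, head in disabled:
        mask[layer][head] = 0.0
    return mask
```

In actual usage one would wrap this in torch.tensor(...) and pass it as the head_mask / decoder_head_mask / cross_attn_head_mask arguments to generate(); with this PR applied, those masks actually affect the generated text, as in the example above.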
Reviewers: @patrickvonplaten