
Generate: use GenerationConfig as the basis for .generate() parametrization #20388

Merged
20 commits merged into huggingface:main from the generate_config_from_model_config branch on Dec 15, 2022

Conversation

@gante (Member) commented Nov 22, 2022

What does this PR do?

This PR introduces generation_config as the main controller of .generate() calls.

In particular:

  1. It adds a from_model_config class method to GenerationConfig, to load a generation config from a (legacy) model config;
  2. Adds a generation_config argument to .generate(). If it is not passed, it is loaded from a pre-determined sequence (check for generation_config.json -> if that fails, load from the model config);
  3. Because .generate() now always has a generation_config holding all parametrization, it gets rid of all the corresponding local variables;
  4. ⚠️ Changes the arguments of generate() (and the corresponding docstring) so as to exclude generation_config parameters (i.e. they were moved to **kwargs). This is mostly to avoid the massive docstring and list of arguments that make .generate() very messy at the moment -- GenerationConfig's docstring explains all the ways .generate() can be controlled, organized by type of manipulation, while .generate()'s docstring focuses on the API.

Note: I've successfully run the SLOW tests of GPT2 (which has a generation_config.json) and BART (which does not) against this PR.
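
A minimal sketch of the call pattern described above (checkpoint name and parameter values are illustrative, and assume a transformers version where GenerationConfig is available):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Build a generation config from the (legacy) model config...
gen_config = GenerationConfig.from_model_config(model.config)
# ...and/or set a few generation parameters on it explicitly.
gen_config.max_new_tokens = 20
gen_config.do_sample = True

inputs = tokenizer("Hello, my name is", return_tensors="pt")
# generation_config now carries the parametrization of .generate();
# ad hoc overrides can still be passed through **kwargs.
outputs = model.generate(**inputs, generation_config=gen_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```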

@gante gante changed the title Generate: use GenerationConfig as the basis for parametrization Generate: use GenerationConfig as the basis for .generate() parametrization Nov 22, 2022
@HuggingFaceDocBuilderDev commented Nov 22, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) left a comment

I really like the fact that we start passing along a generation config instead of 100 kwargs; this feels very consistent with what happens in the model files.

I have no strong opinion on the documentation of model.generate. It's okay with me if we defer to GenerationConfig for the doc.

The main change I'd like to see is having the GenerationConfig be stored (if it exists on the Hub) when we download the model in from_pretrained: with the current implementation, all the hub kwargs like revision, token, etc. are not passed along, and it doesn't feel right to have them on generate. When a user wants a non-standard GenerationConfig, they can use that class's from_pretrained method and pass those kwargs there, but for the default one we should rely on what was passed in the call to ModelClass.from_pretrained.
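
A hedged sketch of the two loading paths described above; the checkpoint name and revision are illustrative, and the default path assumes from_pretrained fetches the repo's generation_config.json as suggested:

```python
from transformers import AutoModelForCausalLM, GenerationConfig

# Default path: hub kwargs go to the model's from_pretrained call, which is
# also where the repo's generation_config.json (if it exists) gets fetched.
model = AutoModelForCausalLM.from_pretrained("gpt2", revision="main")

# Non-standard path: load a specific generation config explicitly, passing
# the hub kwargs to GenerationConfig.from_pretrained rather than to generate().
gen_config = GenerationConfig.from_pretrained("gpt2", revision="main")
```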

Inline review comment on src/transformers/generation/utils.py (outdated, resolved)
@patrickvonplaten (Contributor) commented

Fully agree with @sgugger here.

Totally ok to just link to the GenerationConfig doc page; I think this actually makes the online docs cleaner as well.
Also, I'd maybe rename generation_config to just config in generate, or do you think this would cause confusion with the model's config?

@patrickvonplaten (Contributor) commented

Overall, this is a great improvement!

@gante gante marked this pull request as ready for review November 25, 2022 12:49
@gante (Member, Author) commented Nov 25, 2022

@sgugger @patrickvonplaten It is ready for review.

Major changes since the last review request:

  1. ModelClass.from_pretrained() pre-loads a generation_config attribute onto the model if a generation_config.json exists, as suggested above (see the short sketch after this list)
  2. Handle the case where the model config has nested dictionaries (e.g. a decoder component)
  3. Keep full backward compatibility, including ad hoc model.config changes before calling GenerationMixin functions (that's why you'll see GenerationConfig.from_model_config in so many places, all those functions may be called independently 😓 )
  4. Add documentation and enhance examples
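
A short sketch of change 1, assuming a checkpoint (such as gpt2) whose Hub repo contains a generation_config.json:

```python
from transformers import AutoModelForCausalLM

# from_pretrained now also downloads generation_config.json (when present)
# and attaches it to the model as a ready-to-use attribute.
model = AutoModelForCausalLM.from_pretrained("gpt2")
print(model.generation_config)             # a GenerationConfig instance
print(model.generation_config.max_length)  # individual fields are plain attributes
```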

Also FYI, I'm off until the 8th of Dec 🌴

@sgugger (Collaborator) left a comment

Awesome work! I only have a couple of comments. @patrickvonplaten, if you could find some time to review this PR this week, as you know the insides of generate better than I do, that would be awesome!

Inline review comments on src/transformers/modeling_utils.py (resolved)
@patrickvonplaten (Contributor) left a comment

Looks very nice.

Just one major thing:

I think we should create the generation_config when doing from_pretrained(...) and __init__(...) of a model that is capable of using generate, and then also directly save a generation_config.json. Otherwise GenerationConfig.from_model_config(self.config) is called over and over again in generate, and people won't really switch to using the generation config IMO. Wdyt @sgugger @gante?

Apart from this just left some nits.
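
A hedged sketch of the suggestion above (the save directory is illustrative): derive the generation config once, attach it to the model, and persist it as generation_config.json so GenerationConfig.from_model_config does not have to be re-run on every generate() call.

```python
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Derive the generation config from the model config a single time...
model.generation_config = GenerationConfig.from_model_config(model.config)

# ...and save it next to the model weights as generation_config.json.
model.generation_config.save_pretrained("my-gpt2-checkpoint")
```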

@sgugger (Collaborator) commented Nov 30, 2022

Agreed with you @patrickvonplaten, that's a very good idea!

gante and others added 3 commits December 13, 2022 19:51
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@gante gante force-pushed the generate_config_from_model_config branch from 5747b0e to 09f7ad3 on December 13, 2022 19:51
@gante (Member, Author) commented Dec 14, 2022

@sgugger @patrickvonplaten

Here is a summary of the key changes since your last review:

  • (thanks for the suggestion!) In model.from_pretrained, model.generation_config is set from the model config if the generation config doesn't exist, effectively making all future generation-capable models hold a default generation config parameter. NOTE: this required minor legacy-handling logic for the case where the user makes ad hoc model config changes to control generation (which the previous solution intentionally accounted for)
  • Added a default prepare_inputs_for_generation, which raises NotImplementedError, and updated the new can_generate check accordingly. Contrary to @patrickvonplaten's suggestion, I've kept the _validate_model() check -- it returns an informative exception if the user tries to generate with the wrong class for a model that does have generation capabilities, like AutoModel.from_pretrained("gpt2") (see the short sketch after this list). Not using the right class was a common source of issues in the past.
  • Improved the example to use named generation config files with an actual T5 example. I think two named generation configs would make the example too long 🤔 (cc @patrickvonplaten)
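
An illustrative sketch (not code from the PR) of the failure mode the _validate_model() check guards against, alongside the corrected usage:

```python
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Wrong class: AutoModel returns the bare GPT2Model, which has no LM head.
base_model = AutoModel.from_pretrained("gpt2")
try:
    base_model.generate(**inputs)
except Exception as err:  # the check raises an informative exception here
    print(f"{type(err).__name__}: {err}")

# Right class: the causal LM head makes the model generation-capable.
lm_model = AutoModelForCausalLM.from_pretrained("gpt2")
print(lm_model.generate(**inputs, max_new_tokens=5))
```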

I was thinking of doing the following in a follow-up PR (to avoid adding more features to this already long PR that is blocking Arthur on Whisper work):

  • Add the needed modifications such that model.save_pretrained can push to the hub a default generation config if the file doesn’t yet exist, from the model.generation_config parameter (as @sgugger suggested)

@sgugger (Collaborator) left a comment

Looking great, thanks!

Inline review comment on src/transformers/modeling_utils.py (outdated, resolved)
@gante (Member, Author) commented Dec 15, 2022

@patrickvonplaten -- I'm merging to unblock @ArthurZucker's work on Whisper.

Comments on the points above are still welcome, and I can address them in a subsequent PR! :D

@gante gante merged commit 4bc723f into huggingface:main Dec 15, 2022
@gante gante deleted the generate_config_from_model_config branch December 15, 2022 18:27
gsarti added a commit to gsarti/transformers that referenced this pull request Dec 16, 2022
… add_get_encoder_decoder_fsmt

* 'main' of ssh://github.com/huggingface/transformers: (1433 commits)
  Add Universal Segmentation class + mapping (huggingface#20766)
  Stop calling expand_1d on newer TF versions (huggingface#20786)
  Fix object detection2 (huggingface#20798)
  [Pipeline] skip feature extraction test if in `IMAGE_PROCESSOR_MAPPING` (huggingface#20790)
  Recompile `apex` in `DeepSpeed` CI image (huggingface#20788)
  Move convert_to_rgb to image_transforms module (huggingface#20784)
  Generate: use `GenerationConfig` as the basis for `.generate()` parametrization (huggingface#20388)
  Install video dependency for pipeline CI (huggingface#20777)
  Fixing object detection with `layoutlm` (huggingface#20776)
  [Pipeline] fix failing bloom `pipeline` test (huggingface#20778)
  Patch for FlanT5-XXL 8bit support (huggingface#20760)
  Install vision for TF pipeline tests (huggingface#20771)
  Even more validation. (huggingface#20762)
  Add Swin backbone (huggingface#20769)
  Install `torch-tensorrt 1.3.0` for DeepSpeed CI (huggingface#20764)
  Replaces xxx_required with requires_backends (huggingface#20715)
  [CI-Test] Fixes but also skips the mT5 tests (huggingface#20755)
  Fix attribute error problem  (huggingface#20765)
  [Tests] Improve test_attention_outputs (huggingface#20701)
  Fix missing `()` in some usage of `is_flaky` (huggingface#20749)
  ...
@fxmarty (Contributor) commented Dec 29, 2022

The addition of can_generate() is a breaking change for Optimum, where we use generate() on models that do not inherit from PreTrainedModel. Why isn't can_generate() defined in GenerationMixin? Can a model inherit from GenerationMixin but not use generate()? cc @gante

@gante (Member, Author) commented Dec 29, 2022

@fxmarty can_generate() is called in PreTrainedModel at initialization time, to initialize the (new) generation config if it's a generation-compatible model. All models in transformers inherit GenerationMixin, regardless of whether they can generate, but as a result can_generate() is currently entangling the two classes, which is undesirable.

I may be able to rework this part, but I need to know -- what breaks on your end exactly?

@fxmarty (Contributor) commented Dec 29, 2022

All models in transformers inherit GenerationMixin

Yes thanks, I forgot this part!

The PR I linked fixes the issue on our end. I think what breaks is that generate() is no longer usable on models that do not inherit from PreTrainedModel or that don't redefine can_generate(), because of

if not self.can_generate():

But it's a very minor issue, and the fix is easy, so it's probably not too important.
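
A hedged sketch of the workaround implied here; the class below is illustrative rather than Optimum's actual code, and a real wrapper would still need a config, a forward pass, and prepare_inputs_for_generation:

```python
from transformers import GenerationMixin


class WrappedModelForGeneration(GenerationMixin):
    """Minimal stand-in for a class that reuses generate() without PreTrainedModel."""

    def can_generate(self) -> bool:
        # Opt in explicitly: PreTrainedModel's can_generate() heuristic is not
        # inherited here, and without it the new check makes generate() refuse to run.
        return True
```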

@gante (Member, Author) commented Dec 29, 2022

@fxmarty 👍

In the long run, I'd like to see if it's possible to separate the two (PreTrainedModel and GenerationMixin, where a model only inherits GenerationMixin if it can generate). It should help downstream libraries like optimum!

Let me know if I can be of further assistance.

amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jan 4, 2023
…etrization (huggingface#20388)

* generate from config mvp

* fix failing tests

* max_time test

* Load default gen config at model load time; Update docs

* further documentation; add tests

* adapt rag to the new structure

* handle models not instantiated with from_pretrained (like in tests)

* better default generation config

* add can_generate fn

* handle legacy use case of ad hoc model config changes

* initialize gen config from config in individual methods, if gen config is none

* fix _get_decoder_start_token_id when called outside GenerationMixin

* correct model config load order (set attr > model config > decoder config)

* update rag to match latest changes

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* load gen config from model config in model.from_pretrained

* fix can_generate fn

* handle generate calls without a previous from_pretrained (e.g. tests)

* add legacy behavior (and a warning)

* lower logger severity

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
venkat-natchi pushed a commit to venkat-natchi/transformers that referenced this pull request Jan 22, 2023
…etrization (huggingface#20388)

miyu386 pushed a commit to miyu386/transformers that referenced this pull request Feb 9, 2023
…etrization (huggingface#20388)
