Allow Trainer.get_optimizer_cls_and_kwargs to be overridden #31875

Merged (5 commits into huggingface:main on Jul 11, 2024)

Conversation

@apoorvkh (Contributor) commented on Jul 10, 2024:

What does this PR do?

Currently, Trainer builds an optimizer by loading the optimizer class and arguments from Trainer.get_optimizer_cls_and_kwargs in Trainer.create_optimizer:

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

However, this prevents get_optimizer_cls_and_kwargs() from being overridden in a subclass. As a solution, in this PR I've changed it into an instance method (instead of a @staticmethod) and changed the call site from Trainer.get_optimizer_cls_and_kwargs(args) to self.get_optimizer_cls_and_kwargs(). All existing functionality should remain as is, but this should now be extensible (if you subclass Trainer).
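
Concretely, the call site in Trainer.create_optimizer changes along these lines (a sketch of the proposal, not the exact diff):

# before: the class name is hardcoded, so subclass overrides are bypassed
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

# after (as proposed in this PR): dispatch through the instance
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs()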

Note: I think this breaks the current tests, which expect get_optimizer_cls_and_kwargs to be a static method, e.g.

actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr and @SunMarc

@apoorvkh marked this pull request as ready for review on July 10, 2024 at 03:20.
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) commented:

Hi @apoorvkh, thanks for opening a PR!

Could you give some more details about how you'd like to use this method? I think it should be possible to override a staticmethod in a child class:

In [1]: class Foo:
   ...:     @staticmethod
   ...:     def foo(a, b):
   ...:         return a * b
   ...:
   ...:     def bar(self, c, d):
   ...:         return c ** d
   ...:
   ...:     def baz(self, e, f):
   ...:         return self.foo(e, f)
   ...:
   ...:
   ...: class Bar(Foo):
   ...:     @staticmethod
   ...:     def foo(a, b):
   ...:         return a + b
   ...:

In [2]: Foo.foo(2, 3)
Out[2]: 6

In [3]: Bar.foo(2, 3)
Out[3]: 5

In [4]: Foo().bar(2, 3)
Out[4]: 8

In [5]: Bar().bar(2, 3)
Out[5]: 8

In [6]: Foo().baz(2, 3)
Out[6]: 6

In [7]: Bar().baz(2, 3)
Out[7]: 5

but I might be missing what you're trying to do

@apoorvkh (Contributor, Author) commented on Jul 10, 2024:

Yes, can definitely elaborate:

Say I want to use HF Trainer with an arbitrary PyTorch optimizer (AdamW here just as an example). Then I should intuitively extend Trainer like:

from typing import Any

import torch
from transformers import Trainer, TrainingArguments as HfTrainingArguments


class CustomOptimizerTrainer(Trainer):
    @staticmethod
    def get_optimizer_cls_and_kwargs(args: HfTrainingArguments, model=None) -> tuple[type[torch.optim.Optimizer], dict[str, Any]]:
        optimizer = torch.optim.AdamW
        optimizer_kwargs = {
            "lr": 4e-3,
            "betas": (0.9, 0.999),
            "weight_decay": 0.05,
        }
        return optimizer, optimizer_kwargs

However, this won't take effect, because Trainer.create_optimizer hardcodes the Trainer class name when calling get_optimizer_cls_and_kwargs:

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

This is not self.get_optimizer_cls_and_kwargs and so CustomOptimizerTrainer.get_optimizer_cls_and_kwargs will never be called. I think the best fix is to change Trainer.get_optimizer_cls_and_kwargs to self.get_optimizer_cls_and_kwargs in the original source of Trainer.create_optimizer.
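
To put it in terms of the Foo/Bar example above (a minimal sketch reusing those names): if baz hardcodes Foo.foo instead of calling self.foo, the subclass override is silently bypassed, which is exactly the situation in create_optimizer:

class Foo:
    @staticmethod
    def foo(a, b):
        return a * b

    def baz(self, e, f):
        # hardcoded class name, like Trainer.create_optimizer today
        return Foo.foo(e, f)

class Bar(Foo):
    @staticmethod
    def foo(a, b):
        return a + b

Bar().baz(2, 3)  # returns 6 (the parent's foo); Bar.foo is never called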

I also made get_optimizer_cls_and_kwargs an instance method instead of a static method, but that probably doesn't matter as much and can be reverted; that change is what breaks the existing tests, which call it as a static method.

Please let me know if that's clearer and if you agree! Thanks!

@amyeroberts (Collaborator) commented:

@apoorvkh Thanks for taking the time for the detailed explanation! Yes, I think switching to optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model) should be OK, and a simple solution.

I'd rather we didn't change the method to be an instance method - this is a breaking change which might affect many users downstream.

@apoorvkh (Contributor, Author) commented:

Okay, sounds good then! That makes this a very simple PR. Made those changes and all tests pass :)
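
For reference, the change now boils down to a single line in Trainer.create_optimizer (a sketch of the final call site; get_optimizer_cls_and_kwargs itself stays a staticmethod):

optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

With that, the CustomOptimizerTrainer override sketched earlier takes effect without any further changes.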

@amyeroberts (Collaborator) left a review:

Thanks for this change and making our objects more flexible!

All LGTM - let's just get a second 👍 from @muellerzr or @SunMarc to confirm this is all OK in trainer-land

@apoorvkh (Contributor, Author) commented on Jul 11, 2024:

Thanks! I am also considering making another (simple, no breaking changes) PR to support generic PyTorch optimizers via TrainingArguments. Would you support that idea?

@muellerzr (Contributor) left a review:

Makes sense to me as well, adds some nice flexibility :)

@amyeroberts merged commit 574e68d into huggingface:main on Jul 11, 2024; 21 checks passed.
@apoorvkh deleted the trainer-optimizer-cls-fix branch on July 11, 2024 at 21:14.
@apoorvkh (Contributor, Author) commented:

PR to support generic PyTorch optimizers via TrainingArguments

I was going to add support for something like

TrainingArguments(
    optim=torch.optim.AdamW,
    optim_args={
        "betas" : (0.9, 0.999),
        "eps" : 1e-08,
        "weight_decay" : 0.01
    }
)

(in addition to the existing functionality of optim and optim_args for optimizers implemented in Transformers)

But it looks like TrainingArguments objects must be JSON-serializable. I don't have another approach in mind that is as elegant. We could allow optim="AdamW" and then do optimizer = getattr(torch.optim, args.optim), but this is specific to the optimizers available in torch.optim, and the optim argument is already crowded with strings corresponding to OptimizerNames. Unless someone has a better idea, I will leave this for now :(
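
That said, now that create_optimizer dispatches through self, the getattr(torch.optim, ...) idea is easy to prototype in a subclass. Here is a minimal sketch (the class and its optim_name / optim_extra_kwargs attributes are hypothetical, not part of Transformers; the values mirror the illustrative optim_args above):

import torch
from transformers import Trainer


class TorchOptimTrainer(Trainer):
    # Hypothetical subclass: resolve any optimizer class from torch.optim by name,
    # relying on create_optimizer now calling
    # self.get_optimizer_cls_and_kwargs(self.args, opt_model).
    optim_name = "AdamW"  # any class name available in torch.optim
    optim_extra_kwargs = {"betas": (0.9, 0.999), "eps": 1e-08, "weight_decay": 0.01}

    @staticmethod
    def get_optimizer_cls_and_kwargs(args, model=None):
        optimizer_cls = getattr(torch.optim, TorchOptimTrainer.optim_name)
        # lr comes from TrainingArguments.learning_rate; the remaining kwargs are
        # the illustrative values from the example above
        optimizer_kwargs = {"lr": args.learning_rate, **TorchOptimTrainer.optim_extra_kwargs}
        return optimizer_cls, optimizer_kwargs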

amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jul 19, 2024
…gface#31875)

* Change `Trainer.get_optimizer_cls_and_kwargs` to `self.`

* Make `get_optimizer_cls_and_kwargs` an instance method

* Fixing typo

* Revert `get_optimizer_cls_and_kwargs` to staticmethod

* restore newline to trainer.py eof
MHRDYN7 pushed a commit to MHRDYN7/transformers that referenced this pull request Jul 23, 2024
…gface#31875)

* Change `Trainer.get_optimizer_cls_and_kwargs` to `self.`

* Make `get_optimizer_cls_and_kwargs` an instance method

* Fixing typo

* Revert `get_optimizer_cls_and_kwargs` to staticmethod

* restore newline to trainer.py eof
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jul 24, 2024
…gface#31875)

* Change `Trainer.get_optimizer_cls_and_kwargs` to `self.`

* Make `get_optimizer_cls_and_kwargs` an instance method

* Fixing typo

* Revert `get_optimizer_cls_and_kwargs` to staticmethod

* restore newline to trainer.py eof