-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve trainer API #10354
Improve trainer API #10354
Conversation
- The trainer and inferencer will load params from disk if param_path argument is not None in their constructor. - Remove params.py, we will expose core.Scope to the user if needed (e.g., for GAN). Currently we will not expose it, unless we clearly know doing so can support GAN. - Add `save_params` to Trainer (a TODO item). - rename "network" to "program"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent. It is a good skeleton to implement API.
@@ -100,23 +100,25 @@ def event_handler(event): | |||
word_dict, N)) | |||
|
|||
if avg_cost < 5.0: | |||
trainer.params.save(save_path) | |||
trainer.save_params(save_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not only save_params
should be supported in the trainer, but also save_checkpoint
should be supported by the trainer to restore training progress.
We can add another PR to add this API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, thank you for pointing out!
Fixes: #10352
The trainer and inferencer will load params from disk if param_path
argument is not None in their constructor (API changed, implementation TODO)
Remove params.py, we will expose core.Scope to the user if needed
(e.g., for GAN). Currently we will not expose it, unless we clearly
know doing so can support GAN.
Add
save_params
to Trainer (API changed, implementation TODO)rename "network" to "program"