Support imperative learning rate scheduler #15584

Merged: 19 commits merged into PaddlePaddle:develop on Mar 31, 2019

Conversation

@velconia (Collaborator) commented on Jan 29, 2019:

  1. Move the imperative MNIST unit tests to test_imperative_mnist.
  2. Add optimizer LR scheduler unit tests to test_imperative_optimizer.py.
  3. Implement the LR scheduler in imperative mode (a framework-agnostic sketch follows below).
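To make the imperative mechanism concrete, here is a minimal, framework-agnostic sketch of the pattern this PR describes. The class name, the boundaries/values parameters, and the training loop are illustrative assumptions, not Paddle's actual API: in imperative (dygraph) mode there is no graph to attach decay ops to, so the scheduler keeps its own step counter and is evaluated eagerly.

# Illustrative sketch only -- names are hypothetical, not Paddle's API.
class PiecewiseDecaySketch(object):
    def __init__(self, boundaries, values, begin_step=0):
        # len(values) must be len(boundaries) + 1: one LR per interval.
        assert len(values) == len(boundaries) + 1
        self.boundaries = boundaries
        self.values = values
        self.step_num = begin_step

    def step(self):
        # Advance the counter once per optimizer update.
        self.step_num += 1

    def __call__(self):
        # Return the learning rate for the current step count.
        for i, boundary in enumerate(self.boundaries):
            if self.step_num < boundary:
                return self.values[i]
        return self.values[-1]


# Hypothetical training loop: the current LR is read eagerly each step.
scheduler = PiecewiseDecaySketch(boundaries=[100, 200],
                                 values=[0.1, 0.01, 0.001])
for batch in range(300):
    lr = scheduler()  # learning rate for this update
    # ... forward pass, backward pass, parameter update using lr ...
    scheduler.step()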

__all__ = ['PiecewiseDecay']


class LearningRateDecay(object):
Contributor: Does this support the static graph? Why not use the existing lr_scheduler?

@velconia (Author): Both the static and dynamic learning rates are exposed through the interface in learning_rate_scheduler.py under the fluid directory.

# create learning rate Variable
if isinstance(self._learning_rate, float):
    self._learning_rate_map[framework.default_main_program(
    )] = layers.create_global_var(
Contributor: Can this be created many times?

@velconia (Author): No, it should only be created once, in Optimizer.__init__.
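As a hedged illustration of that "create once" contract (a simplified sketch, not the PR's exact code), the constructor is the single place the learning-rate entry enters the map:

# Simplified sketch, not Paddle's exact code.
class OptimizerSketch(object):
    def __init__(self, learning_rate):
        self._learning_rate = learning_rate
        self._learning_rate_map = {}
        if isinstance(learning_rate, float):
            # Single creation point: later code only reads this entry,
            # so the LR variable is never created twice.
            self._learning_rate_map['main_program'] = learning_rate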

@velconia (Author): Done.

        persistable=True)
# get learning rate Variable from LearningRateDecay
elif isinstance(self._learning_rate, LearningRateDecay):
    self._learning_rate_map[framework.default_main_program(
Contributor: Why use the main program?

@velconia (Author) commented Mar 28, 2019: To keep it consistent with the static-mode code.
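For context, here is a simplified sketch of how a per-program lookup plausibly works, mirroring the static-mode convention of keying the map by the default main program. The helper name is an assumption; `framework` refers to paddle.fluid.framework as in the snippet above.

def _global_learning_rate_sketch(self, program=None):
    # Hypothetical helper: resolve the LR entry registered for a program,
    # defaulting to the current default main program, so the imperative
    # path follows the same convention as static mode.
    if program is None:
        program = framework.default_main_program()
    return self._learning_rate_map.get(program, None)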

@velconia merged commit d8d73ff into PaddlePaddle:develop on Mar 31, 2019