
Usage & concept questions #18

Closed
ChongWu-Biostat opened this issue Oct 27, 2019 · 3 comments
Labels
question Further information is requested

Comments

@ChongWu-Biostat

It works perfectly for me. Thank you for sharing and developing this repo; I think this idea really works (at least for my problem).

Thanks,
Chong

@ChongWu-Biostat
Author

Just a quick question:
Can you explain and provide some guidance about the parameters?

```python
# assuming keras-adamw's exports; adjust the import to match your install
from keras_adamw import AdamW, get_weight_decays, fill_dict_in_order

wd_dict = get_weight_decays(model)  # {'lstm_1/recurrent:0': 0, 'output/kernel:0': 0}
weight_decays = fill_dict_in_order(wd_dict, [4e-4, 1e-4])
# -> {'lstm_1/recurrent:0': 4e-4, 'output/kernel:0': 1e-4}
lr_multipliers = {'lstm_1': 0.5}

optimizer = AdamW(lr=1e-4, weight_decays=weight_decays, lr_multipliers=lr_multipliers,
                  use_cosine_annealing=True, total_iterations=24)
```

If I understand correctly, weight_decays is similar to the L2 penalty. What does lr_multipliers actually stand for? Do its keys have to match the layer names (e.g. "lstm_1")?
What does total_iterations mean?

use_cosine_annealing means we use a large learning rate after some time, right?

Thank you for your help. I think your repo is way better than any other AdamW version in Keras.

@OverLordGoldDragon
Owner

OverLordGoldDragon commented Oct 27, 2019

@ChongWu-Biostat You're welcome, glad you find it useful.

I suppose I'll make a more detailed example in case the README didn't suffice, but for now I'll respond to your questions:


Weight decays vs. L2 penalty

The key difference between weight_decays and the L2 penalty is that the latter is included in the gradient and loss computations, while the former isn't (see the figure in the AdamW paper, "Decoupled Weight Decay Regularization"). It turns out the latter is undesirable, as the L2 penalty gets included in the momentum and RMS (rmsprop) computations, which:

  • Couples (forces a dependency between) the learning rate and lambda (the weight decay rate)
  • Makes weight decay less effective for weight matrices with large gradients
  • Makes weight decay inconsistent across iterations, depending on the gradients

By fixing the weight decay rate and separating it from the loss, all of the above are remedied.
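The contrast can be sketched in a few lines; this is an illustrative single Adam step (not the repo's code, and omitting bias correction), showing how an L2 penalty leaks into the momentum and RMS accumulators while decoupled decay never touches them:

```python
import numpy as np

def adam_step(w, grad, m, v, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
              l2=0.0, weight_decay=0.0):
    if l2:                        # L2: penalty enters the gradient,
        grad = grad + l2 * w      # so it also enters m and v below
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    w = w - lr * m / (np.sqrt(v) + eps)
    if weight_decay:              # decoupled: applied after the adaptive
        w = w - weight_decay * w  # step, never touching m or v
    return w, m, v

w_l2, m_l2, _ = adam_step(1.0, 0.5, 0.0, 0.0, l2=1e-2)
w_wd, m_wd, _ = adam_step(1.0, 0.5, 0.0, 0.0, weight_decay=1e-2)
print(m_l2)  # 0.051 -> penalty contaminated the momentum
print(m_wd)  # 0.05  -> momentum sees only the raw gradient
```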


How to use lr_multipliers?

Suppose you have a model: Input -> Conv1D -> Conv1D -> LSTM -> Dense, and you've pretrained the Conv1D layers for feature extraction, and want to use an LSTM as an additional layer. If you use the same learning rate for all layers, your Conv1D may overfit - and a good workaround is to set per-layer lr, which could look something like:

  • 1e-4 -> 1e-4 -> 1e-3 -> 1e-3

(Input has no lr). So, pretrained layers' lr is 10x less. To achieve this, lr_multipliers detects layers by names specified in lr_multipliers dictionary keys, and applies the multipliers specified in their values to each of the layers. Example:

  • learning_rate=1e-3; lr_multipliers = {'conv1d_1':0.1, 'conv1d_2':0.1}

Names don't have to match exactly; substrings work also: {'conv1d':0.1} will apply 0.1 multiplier to every layer whose name contains the substring 'conv1d'.
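The substring rule above can be sketched as follows; `multiplier_for` is a hypothetical helper for illustration, not a function in the repo:

```python
# A layer gets a multiplier if any dict key is a substring of its name;
# unmatched layers keep the base learning rate (multiplier 1.0).
def multiplier_for(layer_name, lr_multipliers):
    for key, mult in lr_multipliers.items():
        if key in layer_name:
            return mult
    return 1.0

lr_multipliers = {'conv1d': 0.1}
print(multiplier_for('conv1d_2', lr_multipliers))  # 0.1 -> substring match
print(multiplier_for('lstm_1', lr_multipliers))    # 1.0 -> no match
```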


What is cosine annealing?

lr gets multiplied by eta_t, following the cosine-annealing schedule from SGDR: eta_t = 0.5 * (1 + cos(pi * t / T_i)), where t is the current iteration and T_i (= total_iterations) defines the schedule's interval (the max-to-min number of iterations).

For example, with total_iterations=24, at iteration t=12 we have eta_t=0.5, so if your lr=1e-3, the effective lr becomes 5e-4.
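A minimal sketch of that multiplier, assuming the SGDR formula with eta_min=0 and eta_max=1 (the repo's internals may differ in details such as cycle restarts):

```python
import math

def eta_t(t, total_iterations):
    """Cosine-annealing lr multiplier: 1.0 at t=0, 0.0 at t=total_iterations."""
    return 0.5 * (1 + math.cos(math.pi * t / total_iterations))

total_iterations = 24
print(eta_t(0, total_iterations))   # 1.0 -> full lr at the start
print(eta_t(12, total_iterations))  # 0.5 -> lr halved mid-cycle
print(eta_t(24, total_iterations))  # 0.0 -> lr fully annealed
```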

@OverLordGoldDragon OverLordGoldDragon added the question Further information is requested label Oct 27, 2019
@ChongWu-Biostat
Author

Got it. Thank you for your explanation. I understand it now.
