
Detail on WeightDrop class _setup() cuDNN RNN weight compacting issue & register_parameter() #51

esvhd opened this issue Jun 2, 2018 · 7 comments


@esvhd

esvhd commented Jun 2, 2018

Hi there,
cc @Smerity

Thanks for sharing the code, first of all. I've been diving into the details and would really appreciate it if you could share some insight into the WeightDrop class's self._setup() method.

I have 2 questions.

  1. Regarding the comment on the cuDNN RNN weight compacting issue (code here): could anyone expand on what exactly this issue is?

  2. Why does the code delete parameters and then register them again by calling register_parameter()? code here

Thanks.

@LukeMathWalker

I share @esvhd's request: I'd love to have more details on this issue.

@zplizzi

zplizzi commented Jul 30, 2018

I can explain part 2 of the question, but would love an explanation of part 1.

Essentially, in a forward pass of a network with WeightDrop, there need to be two separate copies of each weight parameter. The first copy is weight_raw, which is equivalent to the normal weight variable. The second copy is a version of weight_raw with dropout applied. It is derived from weight_raw, but it has to be named the same as the original weight parameter so that it will be used in the downstream calculations (e.g. a linear layer expects a parameter named weight to use in its computation, so the weight-dropped tensor must have that name).

During backpropagation, the gradient is propagated through weight into weight_raw (see #8 for more details) so that weight_raw accumulates the gradient updates and forms the actual copy of the layer weights.
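
Here's a rough sketch of that mechanism to make it concrete. This is simplified, illustrative code rather than the repo's actual implementation, and it wraps a plain nn.Linear so the cuDNN details from question 1 don't come into play:

```python
import torch
from torch import nn

class WeightDropSketch(nn.Module):
    """Simplified sketch of the two-copy idea (not the repo's exact code)."""

    def __init__(self, module, weight_names, dropout=0.5):
        super().__init__()
        self.module = module
        self.weight_names = weight_names
        self.dropout = dropout
        for name_w in weight_names:
            w = getattr(self.module, name_w)
            # remove the original Parameter and re-register it as `<name>_raw`
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', nn.Parameter(w.data))

    def _setweights(self):
        for name_w in self.weight_names:
            raw_w = getattr(self.module, name_w + '_raw')
            # the dropped copy is a plain tensor derived from the raw Parameter,
            # so during backprop the gradient flows back into `<name>_raw`
            w = nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module(*args)

# The wrapped Linear still sees an attribute called `weight` in its forward,
# but only `weight_raw` (and the bias) are registered, trainable Parameters.
lin = WeightDropSketch(nn.Linear(10, 10), ['weight'], dropout=0.5)
out = lin(torch.randn(4, 10))
print([name for name, _ in lin.named_parameters()])  # module.weight_raw, module.bias
```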

@esvhd
Author

esvhd commented Jul 30, 2018

@zplizzi thanks for the clarification.

Let me make sure I got this right. There are two separate copies of each weight parameter: the first, un-registered copy is needed by the forward pass, and the second, registered copy is used during training for applying dropout and computing gradients? During training, the registered weights are updated, then copied to the un-registered version for eval later?

So for the section of _setup() code I linked to, is it simply making sure that the weights are registered when the network is initialised? Wouldn't weights such as these be registered by default?

Perhaps what I need to understand better is registered vs. un-registered weights in PyTorch...

Thanks.

@zplizzi

zplizzi commented Jul 30, 2018

I think you mostly got it right. In the setup() code, it's actually changing the name of the weights - it's moving them from weight to weight_raw. That way there's a place to create a new version of weight (with dropout applied) in the _setweights() method.

I'm not exactly sure of the significance of weight_raw being registered with register_parameter() while weight is not registered that way. I suspect that's how the optimizer knows which weights it should update during SGD.
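
A quick, generic way to see the difference (plain PyTorch, nothing specific to this repo): only registered parameters show up in .parameters(), which is what gets handed to the optimizer, so a plain tensor attribute like the recreated weight is never updated by SGD directly.

```python
import torch
from torch import nn

m = nn.Module()
# registered: appears in m.parameters(), so the optimizer will update it
m.register_parameter('weight_raw', nn.Parameter(torch.randn(3, 3)))
# un-registered: just a plain tensor attribute, invisible to the optimizer
m.weight = torch.zeros(3, 3)

print([name for name, _ in m.named_parameters()])  # -> ['weight_raw']
opt = torch.optim.SGD(m.parameters(), lr=0.1)       # only ever sees weight_raw
```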

@esvhd
Author

esvhd commented Aug 3, 2018

Thanks @zplizzi , very helpful.

@akurniawan

Hi @zplizzi, thanks for the great explanation. I just tried the code with the new version of PyTorch (1.0); sadly it no longer works, as there is a new parameter-count check in the RNN's internal calculation:
AT_CHECK(params.size() % 4 == 0, "got an incorrect number of RNN parameters");

I'm thinking of changing the code so that weight_raw is no longer stored in the module's parameters but instead in something like a dictionary inside WeightDrop; that way the calculation should stay the same and we can avoid the error. What do you think?
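
Roughly what I have in mind, as a sketch only (the names are made up and I haven't verified this against the cuDNN/flatten_parameters internals): keep the raw weights in a container owned by WeightDrop instead of registering them inside the wrapped RNN, so the RNN's own parameter count stays at exactly 4 per layer.

```python
import torch
from torch import nn

class WeightDropOutside(nn.Module):
    """Sketch: store the raw weights on the wrapper instead of inside the
    wrapped module, so the RNN's internal check
    (params.size() % 4 == 0) still sees the expected parameter count."""

    def __init__(self, module, weight_names, dropout=0.5):
        super().__init__()
        self.module = module
        self.weight_names = weight_names
        self.dropout = dropout
        # ParameterDict lives on the wrapper, not on self.module, so the
        # wrapped RNN keeps only its original parameter names
        self.raw_weights = nn.ParameterDict()
        for name_w in weight_names:
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.raw_weights[name_w] = nn.Parameter(w.data)

    def forward(self, *args):
        for name_w in self.weight_names:
            raw_w = self.raw_weights[name_w]
            w = nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            # re-create the original attribute name as a plain (un-registered) tensor
            setattr(self.module, name_w, w)
        return self.module(*args)
```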

@homelifes

homelifes commented Feb 1, 2019

@zplizzi I do get the point that we are using the dropped version in the forward pass, but when backpropagating we are updating the raw weights. However, in the forward propagation I believe PyTorch searches for the weights named "weight_hh" among the parameters it has, and in that case it cannot find them because the name has changed. So how does it perform the forward prop with the dropout mask?
