## Weight Initialization

PyTorch layers are initialized by default in their respective `reset_parameters()` method. For example:

- `nn.Linear`
  - `weight` and `bias`: uniform distribution [-limit, +limit] where `limit` is `1. / sqrt(fan_in)` and `fan_in` is the number of input units in the weight tensor.
- `nn.Conv2d`
  - `weight` and `bias`: uniform distribution [-limit, +limit] where `limit` is `1. / sqrt(fan_in)` and `fan_in` is the number of input units in the weight tensor.

**Note:** we usually initialize the biases to zero, since symmetry breaking is already provided by the randomly initialized weights ([source](http://cs231n.github.io/neural-networks-2/#init)).
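
For concreteness, here is a minimal sketch of what these defaults amount to for a toy `nn.Linear` layer (the layer sizes are arbitrary, and the actual `reset_parameters()` source differs in detail):

```python
import math

import torch.nn as nn

# default-style init: U(-limit, +limit) with limit = 1 / sqrt(fan_in)
layer = nn.Linear(256, 128)                    # fan_in = 256
limit = 1. / math.sqrt(layer.in_features)
nn.init.uniform_(layer.weight, -limit, limit)
nn.init.uniform_(layer.bias, -limit, limit)
# in practice, biases are often simply zeroed instead: nn.init.zeros_(layer.bias)
```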

#### Xavier Initialization

The default initialization above breaks symmetry, but it is not the best strategy beyond that: ideally, we want the variance of the layer outputs and of the gradients to stay the same across layers.

To achieve this, we initialize the weights from a uniform distribution with `Var(W) = 2 / (fan_in + fan_out)`. Some people choose to sample from a normal distribution instead; see [this discussion](https://github.com/keras-team/keras/issues/52) for empirical evidence that a uniform distribution gives better results.

Using the formula for the variance of a uniform distribution, `Var(U(-a, a)) = a^2 / 3`, we derive that the weights should be initialized from a uniform distribution [-limit, +limit] where `limit = sqrt(6 / (fan_in + fan_out))`.
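
As a quick numerical sanity check (the `fan_in`/`fan_out` values below are arbitrary), samples drawn from this range should indeed have variance `2 / (fan_in + fan_out)`:

```python
import math

import torch

# Var(U(-a, a)) = a^2 / 3, so a = sqrt(6 / (fan_in + fan_out)) gives Var(W) = 2 / (fan_in + fan_out)
fan_in, fan_out = 256, 128
limit = math.sqrt(6. / (fan_in + fan_out))
w = torch.empty(100_000).uniform_(-limit, limit)
print(w.var().item(), 2. / (fan_in + fan_out))  # the two values should roughly agree
```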

**Note:** the above derivation assumed zero-mean inputs and weights. This is not generally the case and may vary with the activation function used. Thus, this initialization is a general-purpose one meant to "work" pretty well in practice. Other initializations can be tailored to particular activations.

```python
# xavier init: U(-limit, +limit) with limit = sqrt(6 / (fan_in + fan_out))
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
```

- [PMLR](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)

#### He et al. Initialization

This is a similarly derived initialization tailored specifically for ReLU activations, since ReLU outputs do not have zero mean. Note that `mode='fan_in'` (the number of inputs to the layer) preserves the variance of the activations in the forward pass, while `mode='fan_out'` preserves it for the gradients in the backward pass.

```python
# he initialization: zero-mean normal with Var(W) = 2 / fan_in
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
```

- [arXiv](https://arxiv.org/abs/1502.01852)

#### SELU Initialization

This initialization is derived specifically for the SELU activation function: weights are drawn from a zero-mean normal distribution with `Var(W) = 1 / fan_in` (the authors use the `fan_in` strategy).

```python
from math import sqrt

# selu init: zero-mean normal with std = sqrt(1 / fan_in)
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
        nn.init.normal_(m.weight, 0, sqrt(1. / fan_in))
    elif isinstance(m, nn.Linear):
        fan_in = m.in_features
        nn.init.normal_(m.weight, 0, sqrt(1. / fan_in))
```

- [arXiv](https://arxiv.org/abs/1706.02515)

#### Orthogonal Initialization

Orthogonality is a desirable quality in NN weights in part because it is norm preserving: multiplying by an orthogonal matrix rotates (or reflects) the input but cannot change its norm, i.e. there is no scaling or shearing. This property is valuable in deep or recurrent networks, where repeated matrix multiplication can cause signals to vanish or explode.

```python
# orthogonal init
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.orthogonal_(m.weight)
```
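
A quick sketch of the norm-preserving property itself, on a toy square weight matrix (sizes are arbitrary):

```python
import torch
import torch.nn as nn

# an orthogonal matrix rotates/reflects a vector but leaves its norm unchanged
w = torch.empty(64, 64)
nn.init.orthogonal_(w)
x = torch.randn(64)
print(x.norm().item(), (w @ x).norm().item())  # the two norms should match closely
```
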
- [Google+ Discussion](https://plus.google.com/+SoumithChintala/posts/RZfdrRQWL6u)
- [Reddit Discussion](https://www.reddit.com/r/MachineLearning/comments/2qsje7/how_do_you_initialize_your_neural_network_weights/)

#### Batch Norm Initialization

```python
# batch norm init: scale (gamma) to 1, shift (beta) to 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
```
