PyTorch layers are initialized by default in their respective reset_parameters()
method. For example:
- nn.Linear: weight and bias are drawn from a uniform distribution [-limit, +limit], where limit is 1. / sqrt(fan_in) and fan_in is the number of input units in the weight tensor.
- nn.Conv2d: weight and bias are drawn from a uniform distribution [-limit, +limit], where limit is 1. / sqrt(fan_in) and fan_in is the number of input units in the weight tensor (in_channels / groups * kernel height * kernel width).
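With this implementation, the variance of the weights is Var(W) = limit^2 / 3 = 1 / (3 * fan_in). For zero-mean, unit-variance inputs, the variance of a layer's outputs is then fan_in * Var(W) = 1/3, so the signal shrinks at every layer. This isn't the best initialization strategy out there.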
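A minimal sketch to check these numbers empirically, assuming a recent PyTorch where nn.Linear keeps this default; the layer sizes 256 and 128 are arbitrary. It compares the sampled weights of a freshly constructed layer with the bound 1 / sqrt(fan_in) and the variance 1 / (3 * fan_in).

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 256, 128  # arbitrary example sizes
layer = nn.Linear(fan_in, fan_out)  # default reset_parameters() init

limit = 1. / math.sqrt(fan_in)
expected_var = limit ** 2 / 3  # variance of U[-limit, +limit] = 1 / (3 * fan_in)

max_abs = layer.weight.abs().max().item()
empirical_var = layer.weight.var().item()
print(f"max |w| = {max_abs:.4f} (bound {limit:.4f})")
print(f"Var(W) = {empirical_var:.6f} (expected {expected_var:.6f})")
# with zero-mean, unit-variance inputs the output variance is fan_in * Var(W), about 1/3
print(f"fan_in * Var(W) = {fan_in * empirical_var:.3f}")
```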
Note that PyTorch provides convenience functions for some of these initializations in torch.nn.init. The fan_in and fan_out of a weight tensor are computed by the helper _calculate_fan_in_and_fan_out(), and calculate_gain() returns a gain that scales the standard deviation to suit a particular activation.
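The two helpers can be called directly. A small sketch with an arbitrary conv layer; note that `_calculate_fan_in_and_fan_out` is a private function (leading underscore), so it may move or change between PyTorch versions.

```python
import math
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)  # weight shape (16, 3, 3, 3)
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(conv.weight)
print(fan_in, fan_out)  # fan_in = 3 * 3 * 3 = 27, fan_out = 16 * 3 * 3 = 144

# calculate_gain returns the recommended scaling factor for a nonlinearity
relu_gain = nn.init.calculate_gain('relu')  # sqrt(2)
tanh_gain = nn.init.calculate_gain('tanh')  # 5 / 3

# e.g. a ReLU-aware standard deviation for this conv layer
std = relu_gain / math.sqrt(fan_in)
print(relu_gain, tanh_gain, std)
```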
The problem with naive random weight initialization (e.g. weights drawn with unit variance) is that the variance of a layer's outputs grows linearly with the number of inputs. Ideally, we want to initialize the weights so that data flows well both forward and backward through the network. That is, we don't want the activations to consistently shrink or grow as we progress through the layers.
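A minimal sketch of the effect, assuming a stack of purely linear, bias-free layers (no activations) and random Gaussian inputs; width, depth and batch size are arbitrary. Unit-variance weights multiply the activation variance by roughly fan_in at every layer, while weights with variance 1 / fan_in keep it roughly constant.

```python
import math
import torch

torch.manual_seed(0)
width, depth = 256, 10  # arbitrary example sizes
x = torch.randn(512, width)  # zero-mean, unit-variance inputs

def forward_std(weight_std):
    """Push x through `depth` bias-free linear layers with N(0, weight_std^2) weights."""
    h = x
    for _ in range(depth):
        w = torch.randn(width, width) * weight_std
        h = h @ w.t()
    return h.std().item()

naive_std = forward_std(1.0)                      # Var(W) = 1
scaled_std = forward_std(1.0 / math.sqrt(width))  # Var(W) = 1 / fan_in
print(f"unit-variance weights: output std = {naive_std:.3e}")  # grows like sqrt(width) ** depth
print(f"1/fan_in variance:     output std = {scaled_std:.3f}")  # stays close to 1
```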
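To keep the activation variance stable, Xavier (Glorot) initialization draws the weights W of a layer from a uniform distribution with Var(W) = 2 / (fan_in + fan_out). Preserving the variance of the activations in the forward pass requires Var(W) = 1 / fan_in, while preserving the variance of the gradients in the backward pass requires Var(W) = 1 / fan_out; 2 / (fan_in + fan_out) is a compromise between the two. Some people choose to sample from a normal distribution instead; see this discussion for empirical evidence that a uniform distribution gives better results.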
Using the formula for the variance of a uniform distribution, Var(W) = limit^2 / 3, we derive that the weights should be initialized from a uniform distribution [-limit, +limit] where limit is sqrt(6) / sqrt(fan_in + fan_out).
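The arithmetic can be checked against PyTorch's built-in Xavier initializer. A sketch with arbitrary layer sizes: `xavier_uniform_` (gain 1 by default) should produce weights bounded by sqrt(6 / (fan_in + fan_out)) with variance close to 2 / (fan_in + fan_out).

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 300, 100  # arbitrary example sizes
w = torch.empty(fan_out, fan_in)  # nn.Linear layout: (out_features, in_features)
nn.init.xavier_uniform_(w)

limit = math.sqrt(6. / (fan_in + fan_out))
target_var = 2. / (fan_in + fan_out)
print(f"limit = {limit:.4f}, max |w| = {w.abs().max().item():.4f}")
print(f"target Var(W) = {target_var:.6f}, empirical = {w.var().item():.6f}")
# the uniform variance formula limit^2 / 3 recovers the target exactly
print(math.isclose(limit ** 2 / 3, target_var))
```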
Note that the above derivation assumes zero-mean inputs and weights. This is not generally the case, and it depends on the activation function used. Thus, this initialization is a general-purpose one that works reasonably well in practice; other initializations can be tailored to particular activations.
It is also worth mentioning that biases usually don't get any special initialization; they are simply set to zero (source).
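```python
import torch.nn as nn

# xavier init (assumes `model` is an existing nn.Module)
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)  # biases start at zero
```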
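An equivalent, commonly used pattern wraps the same logic in a function and passes it to `nn.Module.apply`, which calls it on every submodule. A sketch in which `init_xavier` and the toy model are hypothetical, for illustration only:

```python
import torch
import torch.nn as nn

def init_xavier(m):
    """Xavier-uniform weights and zero biases for conv and linear layers."""
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# hypothetical toy model: 3x4x4 inputs, 10 outputs
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 10),
)
model.apply(init_xavier)
out = model(torch.randn(2, 3, 4, 4))
print(out.shape)  # torch.Size([2, 10])
```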
He (Kaiming) initialization is derived similarly, but tailored specifically to ReLU activations, since their outputs do not have zero mean. It draws the weights with Var(W) = 2 / fan_in.
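```python
import torch.nn as nn

# he initialization (assumes `model` is an existing nn.Module)
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
```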
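A minimal sketch of what the factor 2 buys, assuming a deep stack of bias-free linear layers each followed by ReLU (width, depth and batch size are arbitrary). ReLU zeroes roughly half of its inputs, so with He initialization the mean squared activation stays roughly constant with depth, whereas Xavier initialization (variance 1 / fan_in for a square layer) roughly halves it at every layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
width, depth = 512, 20  # arbitrary example sizes
x = torch.randn(1024, width)  # E[x^2] = 1

def relu_stack_mean_square(init_fn):
    """Mean squared activation after `depth` Linear+ReLU layers initialized by init_fn."""
    h = x
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            init_fn(layer.weight)
            h = torch.relu(layer(h))
    return h.pow(2).mean().item()

he = relu_stack_mean_square(
    lambda w: nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu'))
xavier = relu_stack_mean_square(nn.init.xavier_normal_)
print(f"He:     E[h^2] = {he:.3f}")      # stays close to 1
print(f"Xavier: E[h^2] = {xavier:.2e}")  # shrinks by about 2x per layer
```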
This initialization is derived specifically for the SELU activation function. The authors use the fan_in strategy, drawing the weights from a zero-mean normal distribution with Var(W) = 1 / fan_in. They mention that there is no significant difference between sampling from a Gaussian, a truncated Gaussian, or a uniform distribution.
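```python
import math
import torch.nn as nn

# selu init: N(0, 1 / fan_in) weights (assumes `model` is an existing nn.Module)
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        # each filter sees (in_channels / groups) * kernel_h * kernel_w inputs
        fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels // m.groups
        nn.init.normal_(m.weight, 0, math.sqrt(1. / fan_in))
    elif isinstance(m, nn.Linear):
        fan_in = m.in_features
        nn.init.normal_(m.weight, 0, math.sqrt(1. / fan_in))
```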
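A sketch of the self-normalizing behavior this initialization is meant to support, assuming a stack of bias-free linear layers followed by SELU (sizes are arbitrary). A standard deviation of 1 / sqrt(fan_in) is what `kaiming_normal_` gives with `nonlinearity='linear'` (gain 1), often called LeCun normal; with it, the activations stay close to zero mean and unit variance as depth grows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
width, depth = 512, 16  # arbitrary example sizes
h = torch.randn(1024, width)  # zero-mean, unit-variance inputs
with torch.no_grad():
    for _ in range(depth):
        layer = nn.Linear(width, width, bias=False)
        # std = gain / sqrt(fan_in), with gain 1 for 'linear'
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')
        h = torch.selu(layer(h))
print(f"mean = {h.mean().item():.3f}, std = {h.std().item():.3f}")  # near the fixed point (0, 1)
```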
Orthogonal initialization: orthogonality is a desirable quality in NN weights in part because it is norm-preserving, i.e. an orthogonal matrix rotates (or reflects) its input but cannot change its norm (no scaling or shearing). This property is valuable in deep or recurrent networks, where repeated matrix multiplication can cause signals to vanish or explode.
```python
import torch.nn as nn

# orthogonal init (assumes `model` is an existing nn.Module)
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.orthogonal_(m.weight)
```
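A quick sketch of the norm-preservation property, using an arbitrary square weight: after `orthogonal_`, W W^T is the identity up to float32 rounding, so multiplying by W leaves vector norms unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
w = torch.empty(128, 128)  # arbitrary square weight
nn.init.orthogonal_(w)

x = torch.randn(10, 128)
y = x @ w.t()  # same as nn.Linear(128, 128, bias=False) with weight w
is_orthogonal = torch.allclose(w @ w.t(), torch.eye(128), atol=1e-4)
norms_preserved = torch.allclose(y.norm(dim=1), x.norm(dim=1), atol=1e-4)
print(is_orthogonal, norms_preserved)
```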
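```python
import torch.nn as nn

# batch norm: scale 1, shift 0 (assumes `model` is an existing nn.Module)
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
```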
L2 regularization heavily penalizes peaky weight vectors and encourages diffuse ones. It has the appealing property of encouraging the network to use all of its inputs a little rather than some of its inputs a lot.
```python
import torch

# assumes `model` is an existing nn.Module
reg = 1e-6
l2_loss = 0.0  # becomes a tensor after the first addition
for name, param in model.named_parameters():
    if 'bias' not in name:
        l2_loss = l2_loss + (0.5 * reg * torch.sum(torch.pow(param, 2)))
```
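A sketch of how the penalty interacts with training, on a toy layer chosen for illustration: the gradient of 0.5 * reg * sum(W^2) is reg * W, which is the same term `torch.optim.SGD(..., weight_decay=reg)` adds to each gradient. In a training loop you would add `l2_loss` to the task loss before calling `backward()`. Note that `weight_decay` applies to every parameter in its group, biases included, unless you set up separate parameter groups.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
reg = 1e-6
model = nn.Linear(4, 2)  # toy model for illustration

# penalty computed as in the snippet above (biases excluded)
l2_loss = 0.0
for name, param in model.named_parameters():
    if 'bias' not in name:
        l2_loss = l2_loss + 0.5 * reg * torch.sum(torch.pow(param, 2))

l2_loss.backward()
# the penalty's gradient is reg * W: the same term SGD(weight_decay=reg) adds
print(torch.allclose(model.weight.grad, reg * model.weight))
print(model.bias.grad)  # None: the bias never entered the penalty
```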
L1 regularization encourages sparsity, meaning we encourage the network to select the most useful inputs/features rather than use all of them.
```python
import torch

# assumes `model` is an existing nn.Module
reg = 1e-6
l1_loss = 0.0  # becomes a tensor after the first addition
for name, param in model.named_parameters():
    if 'bias' not in name:
        l1_loss = l1_loss + (reg * torch.sum(torch.abs(param)))
```
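A sketch of where the sparsity comes from, on a toy weight tensor chosen for illustration: the L1 gradient is reg * sign(W), so every nonzero weight is pushed toward zero at the same rate no matter how small it already is, whereas the L2 gradient reg * W fades as the weight shrinks.

```python
import torch

reg = 0.1  # large value so the gradients are easy to read
w = torch.tensor([-2.0, -0.01, 0.01, 2.0], requires_grad=True)

(reg * torch.sum(torch.abs(w))).backward()
l1_grad = w.grad.clone()  # reg * sign(w): [-0.1, -0.1, 0.1, 0.1]

w.grad = None
(0.5 * reg * torch.sum(torch.pow(w, 2))).backward()
l2_grad = w.grad.clone()  # reg * w: [-0.2, -0.001, 0.001, 0.2]

print(l1_grad)
print(l2_grad)
```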
Orthogonal regularization improves gradient flow by keeping the weight matrices close to orthogonal (W W^T close to the identity).
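```python
import torch

# assumes `model` is an existing nn.Module
reg = 1e-6
orth_loss = 0.0  # becomes a tensor after the first addition
for name, param in model.named_parameters():
    if 'bias' not in name and param.dim() > 1:  # only weight matrices / conv kernels
        param_flat = param.reshape(param.shape[0], -1)
        # deviation of W W^T from the identity
        sym = torch.mm(param_flat, torch.t(param_flat))
        sym = sym - torch.eye(param_flat.shape[0], device=param.device)
        # squared Frobenius norm, so positive and negative deviations cannot cancel
        orth_loss = orth_loss + (reg * sym.pow(2).sum())
```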
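A sketch that sanity-checks the penalty on two arbitrary weight matrices: it is essentially zero for a matrix from `orthogonal_` (whose rows are orthonormal when there are no more rows than columns) and large for a plain Gaussian one. `orth_penalty` is a hypothetical helper that mirrors the loop body above without the `reg` factor.

```python
import torch
import torch.nn as nn

def orth_penalty(param):
    """Hypothetical helper: squared Frobenius norm of W W^T - I, W flattened per output unit."""
    param_flat = param.reshape(param.shape[0], -1)
    sym = param_flat @ param_flat.t() - torch.eye(param_flat.shape[0])
    return sym.pow(2).sum()

torch.manual_seed(0)
w_orth = nn.init.orthogonal_(torch.empty(64, 128))  # 64 orthonormal rows
w_rand = torch.randn(64, 128)
p_orth = orth_penalty(w_orth).item()
p_rand = orth_penalty(w_rand).item()
print(p_orth)  # close to 0
print(p_rand)  # large: rows are neither unit-norm nor orthogonal
```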
Max norm constraint: if the L2 norm L of a hidden unit's weight vector ever gets bigger than a certain maximum value c, multiply the weight vector by c/L. Enforce the constraint immediately after each weight update or after every X gradient updates.
This constraint is another form of regularization. While L2 penalizes high weights through the loss function, "max norm" acts directly on the weights. L2 exerts a constant pressure to move the weights near zero, which could throw away useful information when the loss function doesn't provide an incentive for the weights to remain far from zero. On the other hand, "max norm" never drives the weights to near zero. As long as the norm is less than the constraint value, the constraint has no effect.
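```python
import torch

def max_norm(model, max_val=3, eps=1e-8):
    """Rescale every unit's incoming weight vector whose L2 norm exceeds max_val."""
    with torch.no_grad():  # modify the weights in place, outside autograd
        for name, param in model.named_parameters():
            if 'bias' not in name and param.dim() > 1:
                # PyTorch stores weights as (out_features, ...): each row is one unit's weight vector
                norm = param.reshape(param.size(0), -1).norm(2, dim=1)
                # factor c / L for rows above the limit, 1 for rows below it
                scale = torch.clamp(max_val / (norm + eps), max=1.0)
                param.mul_(scale.view(-1, *([1] * (param.dim() - 1))))
```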
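PyTorch also has a built-in that applies the same constraint to a single tensor: `Tensor.renorm_(p, dim, maxnorm)` rescales each slice along `dim` whose p-norm exceeds `maxnorm`. A sketch on a toy linear layer (sizes and `max_val` are arbitrary); with `dim=0` the slices are the rows, i.e. each output unit's incoming weight vector, and in a training loop the call would go right after `optimizer.step()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
max_val = 3.0
layer = nn.Linear(20, 5)  # toy layer for illustration
with torch.no_grad():
    layer.weight[:2] *= 100  # push the first two units well past the limit
    before = layer.weight.clone()
    # would normally run right after optimizer.step()
    layer.weight.renorm_(p=2, dim=0, maxnorm=max_val)

row_norms = layer.weight.norm(dim=1)
print(row_norms)  # first two rows clipped to about max_val, the rest unchanged
```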
[...]
[...]
- Learning Rate
- Batch Size
- Optimizer
- Generalization
- On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima
[...]