[...] = in progress
- Xavier initialization (general-purpose activation)
- He et al. initialization (ReLU activation)
- Orthogonal initialization
- SELU initialization (SELU activation)
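A small sketch of how the activation choice feeds into the init scale: nn.init.calculate_gain returns the recommended gain for a given nonlinearity, which can be passed to Xavier init (the nn.Linear(128, 64) layer here is just a hypothetical example).

import torch.nn as nn

gain = nn.init.calculate_gain('tanh')           # recommended gain for tanh (5/3)
layer = nn.Linear(128, 64)                      # hypothetical layer
nn.init.xavier_normal_(layer.weight, gain=gain)
nn.init.zeros_(layer.bias)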
[...]
[...]
[...]
import math
import torch.nn as nn

# Xavier init
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(m.weight)

# He et al. init
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)

# orthogonal init
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.orthogonal_(m.weight)

# SELU init (LeCun normal: std = sqrt(1 / fan_in))
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
        nn.init.normal_(m.weight, 0, math.sqrt(1. / fan_in))
    elif isinstance(m, nn.Linear):
        fan_in = m.in_features
        nn.init.normal_(m.weight, 0, math.sqrt(1. / fan_in))
For BatchNorm layers, we initialize the weights to 1 and the biases to 0.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
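These per-type rules can be collected into a single function and applied recursively with Module.apply. This is just a sketch assuming a ReLU network; the name init_weights is ours:

def init_weights(m):
    # He init for conv/linear layers followed by ReLU, standard BatchNorm init
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

model.apply(init_weights)   # visits every submodule recursively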
- L2 Regularization: add an L2-norm weight penalty to the loss function.
- L1 Regularization: add an L1-norm weight penalty to the loss function.
- Orthogonal Regularization: add a weight penalty of |W*W.T - I| to the loss function.
- Max Norm Constraint: clamp each weight vector's norm to be less than a constant, W.norm(2) < c.
[...]
[...]
[...]
If the L2 norm L of a hidden unit's weight vector ever exceeds a maximum value c, rescale the weight vector by c/L. Enforce this immediately after each weight update, or after every X gradient updates.
# l2 reg
l2_loss = 0.
for W in model.parameters():
    l2_loss = l2_loss + 0.5 * W.norm(2) ** 2

# l1 reg
l1_loss = 0.
for W in model.parameters():
    l1_loss = l1_loss + W.norm(1)

# orthogonal reg: penalize how far W*W.T is from the identity
orth_loss = 0.
for W in model.parameters():
    W_reshaped = W.view(W.shape[0], -1)
    sym = torch.mm(W_reshaped, torch.t(W_reshaped))
    sym = sym - torch.eye(W_reshaped.shape[0], device=W.device)
    orth_loss = orth_loss + sym.abs().sum()
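A rough sketch of folding one of these terms into the training loss; criterion, output, target and reg_lambda are placeholder names. For a plain L2 penalty with vanilla SGD, the optimizer's built-in weight_decay argument gives the same effect:

reg_lambda = 1e-4                                  # hypothetical penalty strength
loss = criterion(output, target) + reg_lambda * l2_loss
loss.backward()

# equivalent L2 penalty, handled by the optimizer itself (vanilla SGD)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)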
# max norm constraint
def max_norm(model, max_val=3, eps=1e-8):
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'bias' not in name:
                norm = param.norm(2, dim=0, keepdim=True)
                desired = torch.clamp(norm, max=max_val)
                param.mul_(desired / (eps + norm))
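A minimal sketch of where the constraint sits in a training loop (loader, criterion and optimizer are placeholder names); the rescaling runs right after each weight update, as described above:

for inputs, targets in loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    max_norm(model)        # re-project weights right after the update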
[...]
[...]
- Learning Rate
- Batch Size
- Effect on Generalization
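PyTorch ships schedulers for the two learning-rate schedules named in the references below; the numbers here are placeholders, not tuned recommendations:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# option 1: cyclical learning rate (Smith, 2015)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=1e-4, max_lr=1e-1, step_size_up=2000)

# option 2: cosine annealing with warm restarts (SGDR, Loshchilov & Hutter)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2)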
References
- Cyclical Learning Rates for Training Neural Networks
- SGDR: Stochastic Gradient Descent with Warm Restarts
- On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima
- Don't Decay the Learning Rate, Increase the Batch Size
- https://www.reddit.com/r/MachineLearning/comments/77dn96/r_171006451_understanding_generalization_and/dol2u23/
[...]