🏷️sec_weight_decay
Now that we have characterized the problem of overfitting, we can introduce some standard techniques for regularizing models. Recall that we can always mitigate overfitting by going out and collecting more training data. That can be costly, time consuming, or entirely out of our control, making it impossible in the short run. For now, we can assume that we already have as much high-quality data as our resources permit and focus on regularization techniques.
Recall that in our
polynomial regression example
(:numref:`sec_model_selection`)
we could limit our model's capacity
simply by tweaking the degree
of the fitted polynomial.
Indeed, limiting the number of features
is a popular technique to mitigate overfitting.
However, simply tossing aside features
can be too blunt an instrument for the job.
Sticking with the polynomial regression
example, consider what might happen
with high-dimensional inputs.
The natural extensions of polynomials
to multivariate data are called monomials,
which are simply products of powers of variables.
The degree of a monomial is the sum of the powers.
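For example, $x_1^2 x_2$ and $x_3 x_5^2$
are both monomials of degree 3.
Note that the number of terms with degree $d$
grows rapidly as $d$ increases.
Given $k$ variables, the number of monomials of degree $d$
is $\binom{k - 1 + d}{k - 1}$.
Even a small change in degree, say from 2 to 3,
dramatically increases the complexity of the model,
so we often need a more fine-grained tool
for adjusting function complexity.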
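To get a feel for how quickly the number of monomials grows, here is a minimal sketch using only the Python standard library. The helper names `count_monomials` and `enumerate_monomials` are illustrative and not part of `d2l`. The first uses the closed form $\binom{k - 1 + d}{k - 1}$ for $k$ variables and degree $d$, and the second checks it by brute-force enumeration for small cases.

```python
from itertools import combinations_with_replacement
from math import comb


def count_monomials(k, d):
    """Number of distinct monomials of degree exactly d in k variables."""
    return comb(k - 1 + d, k - 1)


def enumerate_monomials(k, d):
    """List each monomial as a tuple of exponents, one entry per variable."""
    monomials = []
    # Choosing d variable indices with repetition fixes one monomial
    for chosen in combinations_with_replacement(range(k), d):
        exponents = [0] * k
        for index in chosen:
            exponents[index] += 1
        monomials.append(tuple(exponents))
    return monomials


# The closed form and the enumeration agree on small cases
for k, d in [(3, 2), (3, 3), (5, 4)]:
    assert count_monomials(k, d) == len(enumerate_monomials(k, d))

# With 200 variables, moving from degree 2 to degree 3 multiplies
# the number of terms by roughly 67
print(count_monomials(200, 2), count_monomials(200, 3))
```

Every monomial is a separate feature with its own weight, so the model's capacity jumps in coarse steps as the degree changes.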
We have described
both the $L_2$ norm and the $L_1$ norm,
which are special cases of the more general $L_p$ norm,
in :numref:`subsec_lin-algebra-norms`.
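Weight decay (commonly called $L_2$ regularization)
is a widely used technique
for regularizing parametric machine learning models.
One simple interpretation might be
to measure the complexity of a linear function
$f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}$
by some norm of its weight vector, e.g., $\|\mathbf{w}\|^2$,
and to add this norm as a penalty term to the loss being minimized.
To make this concrete, recall the example from :numref:`sec_linear_regression`
for linear regression.
There, our loss was given by

$$L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2.$$

Recall that $\mathbf{x}^{(i)}$ are the features,
$y^{(i)}$ are the labels for data examples $i$,
and $(\mathbf{w}, b)$ are the weight and bias parameters, respectively.
To penalize the size of the weight vector,
we add $\|\mathbf{w}\|^2$ to the loss,
scaled by a nonnegative regularization constant $\lambda$:

$$L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2.$$

For $\lambda = 0$, we recover our original loss function.
For $\lambda > 0$, we restrict the size of $\|\mathbf{w}\|$.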
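Moreover, you might ask why we work with the $L_2$ norm
in the first place and not, say, the $L_1$ norm.
Both choices are valid:
$L_2$-regularized linear regression is known as ridge regression,
while $L_1$-regularized linear regression is known as lasso regression.
One reason to work with the $L_2$ norm
is that it places an outsize penalty
on large components of the weight vector,
which biases the learning algorithm towards models
that distribute weight evenly across a larger number of features.
By contrast, $L_1$ penalties tend to concentrate weight
on a small set of features by driving the other weights to zero.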
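The difference between the two penalties can be seen on a toy case. The sketch below uses plain Python and two hand-picked weight vectors (illustrative values, not taken from the text) that have the same $L_1$ norm, one putting all of its weight on a single component and one spreading it evenly.

```python
def l1_penalty(w):
    """Sum of absolute values of the weights."""
    return sum(abs(w_j) for w_j in w)


def squared_l2_penalty(w):
    """Half the sum of squared weights, matching the lambda / 2 convention."""
    return sum(w_j * w_j for w_j in w) / 2


concentrated = [1.0, 0.0, 0.0, 0.0]  # all weight on one feature
spread = [0.25, 0.25, 0.25, 0.25]    # same total weight, evenly distributed

for name, w in [('concentrated', concentrated), ('spread', spread)]:
    print(name, 'L1:', l1_penalty(w), 'squared L2 / 2:', squared_l2_penalty(w))
```

The $L_1$ penalty cannot tell the two vectors apart, whereas the squared $L_2$ penalty is four times larger for the concentrated vector. This is the sense in which the $L_2$ penalty favors spreading weight across features.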
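Using the same notation in :eqref:`eq_linreg_batch_update`,
the minibatch stochastic gradient descent updates
for $L_2$-regularized regression are

$$\mathbf{w} \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right),$$

where $\eta$ is the learning rate, $\mathcal{B}$ is the minibatch,
and $\lambda$ is the regularization constant.
As before, we update $\mathbf{w}$ based on the amount
by which our estimate differs from the observation.
However, we also shrink the size of $\mathbf{w}$ towards zero,
which is why the method is called "weight decay".
Whether we include a corresponding bias penalty $b^2$
can vary across implementations.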
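The name "weight decay" comes from the fact that a gradient step on the penalized objective is the same as first shrinking the weights by a factor $(1 - \eta\lambda)$ and then taking an ordinary gradient step. The following is a minimal sketch in plain Python that checks this on a tiny hand-made minibatch. The function names and the numbers are illustrative assumptions, not part of `d2l`.

```python
def grad_squared_loss(w, b, X, y):
    """Gradient w.r.t. w of the minibatch mean of 0.5 * (w.x + b - y)^2."""
    grad = [0.0] * len(w)
    for x_i, y_i in zip(X, y):
        err = sum(w_j * x_j for w_j, x_j in zip(w, x_i)) + b - y_i
        for j, x_j in enumerate(x_i):
            grad[j] += err * x_j / len(X)
    return grad


def step_penalized(w, b, X, y, lr, lambd):
    """Gradient step on loss + (lambd / 2) * ||w||^2."""
    g = grad_squared_loss(w, b, X, y)
    # The penalty contributes lambd * w_j to the gradient
    return [w_j - lr * (g_j + lambd * w_j) for w_j, g_j in zip(w, g)]


def step_decay(w, b, X, y, lr, lambd):
    """Shrink w by (1 - lr * lambd), then take a plain gradient step."""
    g = grad_squared_loss(w, b, X, y)
    return [(1 - lr * lambd) * w_j - lr * g_j for w_j, g_j in zip(w, g)]


# Toy minibatch of two examples with two features each
X, y = [[1.0, 2.0], [0.5, -1.0]], [1.0, 0.0]
w, b = [0.3, -0.2], 0.1

penalized = step_penalized(w, b, X, y, lr=0.1, lambd=3.0)
decayed = step_decay(w, b, X, y, lr=0.1, lambd=3.0)
# The two forms agree up to floating-point rounding
assert all(abs(p - d) < 1e-12 for p, d in zip(penalized, decayed))
print(penalized, decayed)
```

Because the decay term depends only on the current value of each weight, it can be applied without looking at the data at all.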
We can illustrate the benefits of weight decay through a simple synthetic example.
%matplotlib inline
from d2l import mxnet as d2l
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn
npx.set_np()
#@tab pytorch
%matplotlib inline
from d2l import torch as d2l
import torch
import torch.nn as nn
#@tab tensorflow
%matplotlib inline
from d2l import tensorflow as d2l
import tensorflow as tf
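First, we generate some data as before:

$$y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2).$$

We choose our label to be a linear function of our inputs,
corrupted by Gaussian noise with zero mean and standard deviation 0.01.
To make the effects of overfitting pronounced,
we can increase the dimensionality of our problem to $d = 200$
and work with a small training set containing only 20 examples.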
#@tab all
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = d2l.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)
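For readers who want to see the data-generating process without the `d2l` helpers, here is a rough standard-library sketch. It assumes the features are drawn from a standard normal distribution, which this section does not state explicitly. The function `make_synthetic_data` and its `seed` argument are hypothetical stand-ins, not the `d2l.synthetic_data` implementation.

```python
import random


def make_synthetic_data(true_w, true_b, num_examples, noise_std=0.01, seed=0):
    """Generate y = w.x + b + noise with standard normal features (assumed)."""
    rng = random.Random(seed)
    features, labels = [], []
    for _ in range(num_examples):
        x = [rng.gauss(0, 1) for _ in true_w]
        y = sum(w_j * x_j for w_j, x_j in zip(true_w, x)) + true_b
        features.append(x)
        labels.append(y + rng.gauss(0, noise_std))
    return features, labels


num_inputs = 200
true_w, true_b = [0.01] * num_inputs, 0.05
# Far fewer training examples (20) than input dimensions (200)
train_X, train_y = make_synthetic_data(true_w, true_b, 20)
test_X, test_y = make_synthetic_data(true_w, true_b, 100, seed=1)
print(len(train_X), len(train_X[0]), len(test_X))
```

With ten times more parameters than training examples, a linear model can fit the training labels exactly in many different ways, which is what makes overfitting so easy to provoke here.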
In the following, we will implement weight decay from scratch,
simply by adding the squared $L_2$ penalty
to the original target function.
First, we will define a function to randomly initialize our model parameters.
def init_params():
    w = np.random.normal(scale=1, size=(num_inputs, 1))
    b = np.zeros(1)
    w.attach_grad()
    b.attach_grad()
    return [w, b]
#@tab pytorch
def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]
#@tab tensorflow
def init_params():
    # Zero mean and unit standard deviation, matching the other frameworks
    w = tf.Variable(tf.random.normal(mean=0, stddev=1, shape=(num_inputs, 1)))
    b = tf.Variable(tf.zeros(shape=(1, )))
    return [w, b]
Perhaps the most convenient way to implement this penalty is to square all terms in place and sum them up.
def l2_penalty(w):
    return (w**2).sum() / 2
#@tab pytorch
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2
#@tab tensorflow
def l2_penalty(w):
    return tf.reduce_sum(tf.pow(w, 2)) / 2
The following code fits a model on the training set
and evaluates it on the test set.
The linear network and the squared loss
have not changed since :numref:`chap_linear`,
so we will just import them via `d2l.linreg` and `d2l.squared_loss`.
The only change here is that our loss now includes the penalty term.
def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            with autograd.record():
                # Add the L2 norm penalty; broadcasting adds the scalar
                # `lambd * l2_penalty(w)` to each per-example loss
                l = loss(net(X), y) + lambd * l2_penalty(w)
            l.backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('L2 norm of w:', np.linalg.norm(w))
#@tab pytorch
def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            with torch.enable_grad():
                # Add the L2 norm penalty; broadcasting adds the scalar
                # `lambd * l2_penalty(w)` to each per-example loss
                l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('L2 norm of w:', torch.norm(w).item())
#@tab tensorflow
def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            with tf.GradientTape() as tape:
                # Add the L2 norm penalty; broadcasting adds the scalar
                # `lambd * l2_penalty(w)` to each per-example loss
                l = loss(net(X), y) + lambd * l2_penalty(w)
            grads = tape.gradient(l, [w, b])
            d2l.sgd([w, b], grads, lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('L2 norm of w:', tf.norm(w).numpy())
We now run this code with `lambd = 0`,
disabling weight decay.
Note that we overfit badly,
decreasing the training error but not the
test error, a textbook case of overfitting.
#@tab all
train(lambd=0)
Below, we run with substantial weight decay. Note that the training error increases but the test error decreases. This is precisely the effect we expect from regularization.
#@tab all
train(lambd=3)
Because weight decay is ubiquitous in neural network optimization, the deep learning framework makes it especially convenient, integrating weight decay into the optimization algorithm itself for easy use in combination with any loss function. Moreover, this integration serves a computational benefit, allowing implementation tricks to add weight decay to the algorithm, without any additional computational overhead. Since the weight decay portion of the update depends only on the current value of each parameter, the optimizer must touch each parameter once anyway.
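To illustrate how an optimizer can fold weight decay into its update, here is a deliberately tiny sketch in plain Python. The class `TinySGD` is hypothetical and is not the API of any framework. It only mimics the idea of parameter groups that each carry their own `weight_decay` setting, so that weights can be decayed while the bias is left alone.

```python
class TinySGD:
    """Hypothetical minimal SGD optimizer with per-group weight decay."""

    def __init__(self, param_groups, lr):
        # Each group is a dict with a list of floats under "params" and
        # an optional "weight_decay" coefficient (defaults to no decay)
        self.param_groups = param_groups
        self.lr = lr

    def step(self, grads):
        # grads[g][j] is the loss gradient for parameter j of group g
        for group, group_grads in zip(self.param_groups, grads):
            wd = group.get('weight_decay', 0.0)
            params = group['params']
            for j, g in enumerate(group_grads):
                # The decay term is added in the same pass over the
                # parameters that applies the gradient
                params[j] -= self.lr * (g + wd * params[j])


weights, bias = [1.0, -2.0], [0.5]
optimizer = TinySGD([{'params': weights, 'weight_decay': 3.0},
                     {'params': bias}], lr=0.1)

# With zero loss gradients, only the decayed group changes
optimizer.step([[0.0, 0.0], [0.0]])
print(weights, bias)
```

With a zero gradient, each weight shrinks by the factor $1 - 0.1 \times 3 = 0.7$ while the bias stays put, which is the behavior the framework-specific code configures through its own options.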
:begin_tab:mxnet
In the following code, we specify
the weight decay hyperparameter directly
through `wd` when instantiating our `Trainer`.
By default, Gluon decays both
weights and biases simultaneously.
Note that the hyperparameter `wd`
will be multiplied by `wd_mult`
when updating model parameters.
Thus, if we set `wd_mult` to zero,
the bias parameter $b$ will not decay.
:end_tab:
:begin_tab:pytorch
In the following code, we specify
the weight decay hyperparameter directly
through `weight_decay` when instantiating our optimizer.
By default, PyTorch decays both
weights and biases simultaneously.
Here we only set `weight_decay` for the weight,
so the bias parameter $b$ will not decay.
:end_tab:
:begin_tab:tensorflow
In the following code, we create an $L_2$ regularizer
with the weight decay hyperparameter `wd`
and apply it to the layer
through the `kernel_regularizer` argument.
:end_tab:
def train_concise(wd):
    net = nn.Sequential()
    net.add(nn.Dense(1))
    net.initialize(init.Normal(sigma=1))
    loss = gluon.loss.L2Loss()
    num_epochs, lr = 100, 0.003
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'wd': wd})
    # Exclude the bias from weight decay. Bias names generally end with "bias"
    net.collect_params('.*bias').setattr('wd_mult', 0)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            with autograd.record():
                l = loss(net(X), y)
            l.backward()
            trainer.step(batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('L2 norm of w:', np.linalg.norm(net[0].weight.data()))
#@tab pytorch
def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss()
    num_epochs, lr = 100, 0.003
    # Apply weight decay to the weight only; the bias is not decayed
    trainer = torch.optim.SGD([
        {"params": net[0].weight, 'weight_decay': wd},
        {"params": net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            with torch.enable_grad():
                trainer.zero_grad()
                l = loss(net(X), y)
            l.backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('L2 norm of w:', net[0].weight.norm().item())
#@tab tensorflow
def train_concise(wd):
    net = tf.keras.models.Sequential()
    net.add(tf.keras.layers.Dense(
        1, kernel_regularizer=tf.keras.regularizers.l2(wd)))
    net.build(input_shape=(1, num_inputs))
    w, b = net.trainable_variables
    loss = tf.keras.losses.MeanSquaredError()
    num_epochs, lr = 100, 0.003
    trainer = tf.keras.optimizers.SGD(learning_rate=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            with tf.GradientTape() as tape:
                # `tf.keras` requires retrieving and adding the losses from
                # layers manually for custom training loop.
                l = loss(net(X), y) + net.losses
            grads = tape.gradient(l, net.trainable_variables)
            trainer.apply_gradients(zip(grads, net.trainable_variables))
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('L2 norm of w:', tf.norm(net.get_weights()[0]).numpy())
The plots look identical to those when we implemented weight decay from scratch. However, they run appreciably faster and are easier to implement, a benefit that will become more pronounced for larger problems.
#@tab all
train_concise(0)
#@tab all
train_concise(3)
So far, we have only touched upon one notion of what constitutes a simple linear function. Moreover, what constitutes a simple nonlinear function can be an even more complex question. For instance, reproducing kernel Hilbert space (RKHS) allows one to apply tools introduced for linear functions in a nonlinear context. Unfortunately, RKHS-based algorithms tend to scale poorly to large, high-dimensional data. In this book we will default to the simple heuristic of applying weight decay on all layers of a deep network.
- Regularization is a common method for dealing with overfitting. It adds a penalty term to the loss function on the training set to reduce the complexity of the learned model.
- One particular choice for keeping the model simple is weight decay using an $L_2$ penalty. This leads to weight decay in the update steps of the learning algorithm.
- The weight decay functionality is provided in optimizers from deep learning frameworks.
- Different sets of parameters can have different update behaviors within the same training loop.
- Experiment with the value of $\lambda$ in the estimation problem in this section. Plot training and test accuracy as a function of $\lambda$. What do you observe?
- Use a validation set to find the optimal value of $\lambda$. Is it really the optimal value? Does this matter?
- What would the update equations look like if instead of $\|\mathbf{w}\|^2$ we used $\sum_i |w_i|$ as our penalty of choice ($L_1$ regularization)?
- We know that $\|\mathbf{w}\|^2 = \mathbf{w}^\top \mathbf{w}$. Can you find a similar equation for matrices (see the Frobenius norm in :numref:`subsec_lin-algebra-norms`)?
- Review the relationship between training error and generalization error. In addition to weight decay, more training data, and the use of a model of suitable complexity, what other ways can you think of to deal with overfitting?
- In Bayesian statistics we use the product of prior and likelihood to arrive at a posterior via $P(w \mid x) \propto P(x \mid w) P(w)$. How can you identify $P(w)$ with regularization?