
Commit 0a632f4
Merge pull request d2l-ai#150 from astonzhang/rnn
update utils
Aston Zhang authored Jan 12, 2018
2 parents 9ba840c + 82d7e28 commit 0a632f4
Showing 1 changed file with 6 additions and 6 deletions.
utils.py: 12 changes (6 additions, 6 deletions)
@@ -239,16 +239,16 @@ def data_iter_consecutive(corpus_indices, batch_size, num_steps, ctx=None):
         yield data, label


-def grad_clipping(params, theta, ctx):
+def grad_clipping(params, clipping_norm, ctx):
     """Gradient clipping."""
-    if theta is not None:
+    if clipping_norm is not None:
         norm = nd.array([0.0], ctx)
         for p in params:
             norm += nd.sum(p.grad ** 2)
         norm = nd.sqrt(norm).asscalar()
-        if norm > theta:
+        if norm > clipping_norm:
             for p in params:
-                p.grad[:] *= theta / norm
+                p.grad[:] *= clipping_norm / norm


 def predict_rnn(rnn, prefix, num_chars, params, hidden_dim, ctx, idx_to_char,
@@ -274,7 +274,7 @@ def predict_rnn(rnn, prefix, num_chars, params, hidden_dim, ctx, idx_to_char,


 def train_and_predict_rnn(rnn, is_random_iter, epochs, num_steps, hidden_dim,
-                          learning_rate, clipping_theta, batch_size,
+                          learning_rate, clipping_norm, batch_size,
                           pred_period, pred_len, seqs, get_params, get_inputs,
                           ctx, corpus_indices, idx_to_char, char_to_idx,
                           is_lstm=False):
@@ -321,7 +321,7 @@ def train_and_predict_rnn(rnn, is_random_iter, epochs, num_steps, hidden_dim,
                 loss = softmax_cross_entropy(outputs, label)
             loss.backward()

-            grad_clipping(params, clipping_theta, ctx)
+            grad_clipping(params, clipping_norm, ctx)
             SGD(params, learning_rate)

             train_loss += nd.sum(loss).asscalar()
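For reference, the parameter renamed here is the gradient-clipping threshold: grad_clipping computes the global L2 norm of all parameter gradients and, if it exceeds clipping_norm, rescales every gradient by clipping_norm / norm, capping the norm without changing the gradient's direction. Below is a minimal, self-contained sketch (not part of this commit) that exercises the updated signature; it assumes MXNet's classic NDArray API, and the toy 4x4 parameter and the 0.5 threshold are invented for illustration.

import mxnet as mx
from mxnet import nd


def grad_clipping(params, clipping_norm, ctx):
    """Scale all gradients so their global L2 norm is at most clipping_norm."""
    if clipping_norm is not None:
        norm = nd.array([0.0], ctx)
        for p in params:
            norm += nd.sum(p.grad ** 2)
        norm = nd.sqrt(norm).asscalar()
        if norm > clipping_norm:
            for p in params:
                p.grad[:] *= clipping_norm / norm


ctx = mx.cpu()
params = [nd.random.normal(shape=(4, 4), ctx=ctx)]
for p in params:
    p.attach_grad()

# Stand in for a real backward pass by writing into the gradient buffer:
# a uniform gradient of 10.0 gives a global norm of sqrt(16 * 100) = 40.
params[0].grad[:] = nd.ones((4, 4), ctx=ctx) * 10.0

grad_clipping(params, 0.5, ctx)

total = nd.array([0.0], ctx)
for p in params:
    total += nd.sum(p.grad ** 2)
print(nd.sqrt(total).asscalar())  # ~0.5: the norm is now capped at the threshold

Because the same scale factor is applied to every gradient, clipping preserves the update direction while bounding its magnitude, which is what keeps the training loop in train_and_predict_rnn stable when RNN gradients explode.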

