[src] Bug-fix and improvements to stability for RNNLM code
danpovey committed Aug 14, 2017
1 parent 377b786 commit 99113da
Showing 11 changed files with 136 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/rnnlm/Makefile
@@ -14,7 +14,7 @@ OBJFILES = sampler.o arpa-sampling.o rnnlm-example.o rnnlm-example-utils.o \
LIBNAME = kaldi-rnnlm

ADDLIBS = ../nnet3/kaldi-nnet3.a ../cudamatrix/kaldi-cudamatrix.a \
../matrix/kaldi-matrix.a ../util/kaldi-util.a ../base/kaldi-base.a \
../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a \
../lm/kaldi-lm.a ../hmm/kaldi-hmm.a

include ../makefiles/default_rules.mk
6 changes: 4 additions & 2 deletions src/rnnlm/arpa-sampling.cc
@@ -35,8 +35,8 @@ void ArpaSampling::ConsumeNGram(const NGram& ngram) {
unigram_probs_.resize(static_cast<size_t>(word + 1), 0.0);
KALDI_ASSERT(unigram_probs_[word] == 0.0); // or repeated unigram.
unigram_probs_[word] = Exp(ngram.logprob);
if (ngram.backoff != 0.0)
higher_order_probs_[cur_order - 1][ngram.words].backoff_prob =
if (ngram.backoff != 0.0)
higher_order_probs_[cur_order - 1][ngram.words].backoff_prob =
Exp(ngram.backoff);
} else {
HistType history(ngram.words.begin(), ngram.words.end() - 1);
@@ -210,6 +210,8 @@ BaseFloat ArpaSampling::GetDistribution(
non_unigram_probs_out->insert(non_unigram_probs_out->end(),
non_unigram_probs_temp.begin(),
non_unigram_probs_temp.end());
std::sort(non_unigram_probs_out->begin(),
non_unigram_probs_out->end());
return ans;
}
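
Note on the std::sort added above: it appears to leave the (word-id, probability) entries of non_unigram_probs_out ordered by word id (assuming the output is a vector of pairs, which the default pair comparison sorts by first element), so a caller could, for example, look a word up by binary search. A minimal, hypothetical sketch; the vector-of-pairs type and the LookupProb helper are illustrative and not part of the Kaldi API:

#include <algorithm>
#include <utility>
#include <vector>

typedef int int32;        // illustrative stand-ins for Kaldi's base types
typedef float BaseFloat;

// Comparator so lower_bound can compare a (word, prob) pair against a word id.
struct PairFirstLess {
  bool operator()(const std::pair<int32, BaseFloat> &p, int32 w) const {
    return p.first < w;
  }
};

// Returns the probability stored for 'word', or 0.0 if it is absent;
// assumes 'probs' is sorted by word id.
BaseFloat LookupProb(const std::vector<std::pair<int32, BaseFloat> > &probs,
                     int32 word) {
  std::vector<std::pair<int32, BaseFloat> >::const_iterator it =
      std::lower_bound(probs.begin(), probs.end(), word, PairFirstLess());
  return (it != probs.end() && it->first == word) ? it->second : 0.0;
}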

8 changes: 5 additions & 3 deletions src/rnnlm/rnnlm-core-compute.cc
@@ -103,7 +103,11 @@ BaseFloat RnnlmCoreComputer::ProcessOutput(
output_deriv.Resize(output.NumRows(), output.NumCols());

BaseFloat weight, objf_num, objf_den, objf_den_exact;
ProcessRnnlmOutput(minibatch, derived, word_embedding,


RnnlmObjectiveOptions objective_opts; // Use the defaults; we're not training
// so they won't matter.
ProcessRnnlmOutput(objective_opts, minibatch, derived, word_embedding,
output, word_embedding_deriv, &output_deriv,
&weight, &objf_num, &objf_den,
&objf_den_exact);
@@ -120,5 +124,3 @@ BaseFloat RnnlmCoreComputer::ProcessOutput(

} // namespace rnnlm
} // namespace kaldi


5 changes: 4 additions & 1 deletion src/rnnlm/rnnlm-core-training.cc
@@ -121,8 +121,10 @@ void ObjectiveTracker::PrintStatsOverall() const {


RnnlmCoreTrainer::RnnlmCoreTrainer(const RnnlmCoreTrainerOptions &config,
const RnnlmObjectiveOptions &objective_config,
nnet3::Nnet *nnet):
config_(config),
objective_config_(objective_config),
nnet_(nnet),
compiler_(*nnet), // for now we don't make available other options
num_minibatches_processed_(0),
@@ -245,7 +247,8 @@ void RnnlmCoreTrainer::ProcessOutput(
output_deriv.Resize(output.NumRows(), output.NumCols());

BaseFloat weight, objf_num, objf_den, objf_den_exact;
ProcessRnnlmOutput(minibatch, derived, word_embedding,
ProcessRnnlmOutput(objective_config_,
minibatch, derived, word_embedding,
output, word_embedding_deriv, &output_deriv,
&weight, &objf_num, &objf_den,
&objf_den_exact);
2 changes: 2 additions & 0 deletions src/rnnlm/rnnlm-core-training.h
@@ -124,6 +124,7 @@ class RnnlmCoreTrainer {
Will be modified each time you call Train().
*/
RnnlmCoreTrainer(const RnnlmCoreTrainerOptions &config,
const RnnlmObjectiveOptions &objective_config,
nnet3::Nnet *nnet);

/* Train on one minibatch.
@@ -187,6 +188,7 @@
void UpdateParamsWithMaxChange();

const RnnlmCoreTrainerOptions config_;
const RnnlmObjectiveOptions objective_config_;
nnet3::Nnet *nnet_;
nnet3::Nnet *delta_nnet_; // nnet representing parameter-change for this
// minibatch (or, when using momentum, its moving
15 changes: 11 additions & 4 deletions src/rnnlm/rnnlm-example-test.cc
@@ -58,12 +58,13 @@ void TestRnnlmTraining(const std::string &archive_rxfilename,

RnnlmCoreTrainerOptions core_config;
RnnlmEmbeddingTrainerOptions embedding_config;
RnnlmObjectiveOptions objective_config;

bool train_embedding = (RandInt(0, 1) == 0);

{
RnnlmTrainer trainer(train_embedding, core_config, embedding_config,
NULL, &embedding_mat, rnnlm);
objective_config, NULL, &embedding_mat, rnnlm);
for (; !reader.Done(); reader.Next()) {
trainer.Train(&reader.Value());
}
@@ -108,7 +109,9 @@ void TestRnnlmOutput(const std::string &archive_rxfilename) {

BaseFloat weight, objf_num, objf_den, objf_den_exact;

ProcessRnnlmOutput(example, derived, embedding, nnet_output,
RnnlmObjectiveOptions objective_config;
ProcessRnnlmOutput(objective_config,
example, derived, embedding, nnet_output,
train_embedding ? &embedding_deriv : NULL,
train_nnet ? &nnet_output_deriv : NULL,
&weight, &objf_num, &objf_den, &objf_den_exact);
@@ -141,7 +144,9 @@ void TestRnnlmOutput(const std::string &archive_rxfilename) {
<< ", smat sum is " << derived.output_words_smat.Sum();

BaseFloat weight2, objf_num2, objf_den2;
ProcessRnnlmOutput(example, derived, embedding2, nnet_output,
RnnlmObjectiveOptions objective_config;
ProcessRnnlmOutput(objective_config,
example, derived, embedding2, nnet_output,
NULL, NULL,
&weight2, &objf_num2, &objf_den2, NULL);
objf_change_observed(i) = (objf_num2 + objf_den2) -
@@ -178,7 +183,9 @@ void TestRnnlmOutput(const std::string &archive_rxfilename) {
<< ", smat sum is " << derived.output_words_smat.Sum();

BaseFloat weight2, objf_num2, objf_den2;
ProcessRnnlmOutput(example, derived, embedding, nnet_output2,
RnnlmObjectiveOptions objective_config;
ProcessRnnlmOutput(objective_config,
example, derived, embedding, nnet_output2,
NULL, NULL,
&weight2, &objf_num2, &objf_den2, NULL);
objf_change_observed(i) = (objf_num2 + objf_den2) -
69 changes: 58 additions & 11 deletions src/rnnlm/rnnlm-example-utils.cc
@@ -72,6 +72,7 @@ void RnnlmExampleDerived::Swap(RnnlmExampleDerived *other) {
// This is called from ProcessRnnlmOutput() when we are doing importance
// sampling.
static void ProcessRnnlmOutputSampling(
const RnnlmObjectiveOptions &objective_config,
const RnnlmExample &minibatch,
const RnnlmExampleDerived &derived,
const CuMatrixBase<BaseFloat> &word_embedding,
@@ -82,6 +83,7 @@ static void ProcessRnnlmOutputSampling(
BaseFloat *objf_num,
BaseFloat *objf_den,
BaseFloat *objf_den_exact) {
KALDI_ASSERT(weight != NULL && objf_den != NULL); // Others are optional.

// In the case where minibatch.sample_group_size == 1, meaning for each 't' value we
// sample separately, num_sample_groups would equal the chunk_length and
@@ -162,11 +164,8 @@

// The denominator part of this objective is something like:
// - \sum_i \sum_w output_weight(i) * q(i, w) * sample_inv_prob(w).
if (objf_den) {
*objf_den +=
-VecMatVec(output_weights_part, word_logprobs,
sample_inv_probs_part);
}
*objf_den += -VecMatVec(output_weights_part, word_logprobs,
sample_inv_probs_part);


// The derivative of the function q(l) = (l < 0 ? exp(l) : l + 1.0)
@@ -179,6 +178,26 @@
// (which is what we're computing now).
word_logprobs.MulColsVec(sample_inv_probs_part);

if (objective_config.den_term_limit != 0.0) {
// If it's nonzero then check that it's negative, and not too close to zero,
// which would likely cause failure to converge. The default is -10.0.
KALDI_ASSERT(objective_config.den_term_limit < -0.5);
BaseFloat limit = objective_config.den_term_limit;
if (weight != NULL && objf_den != NULL && *weight > 0 &&
(*objf_den / *weight) < limit) {
// note: both things being divided below are negative, and
// 'scale' will be between zero and one.
BaseFloat scale = limit / (*objf_den / *weight);
// We scale the denominator part of the objective down by the inverse of
// the factor by which the denominator part of the objective exceeds the
// limit. This point in the code should only be reached on the first few
// iterations of training, or if there is some kind of instability,
// because the (*objf_den / *weight) will usually be close to zero,
// e.g. -0.01, while 'limit' is expected to be much larger in magnitude, like -10.0.
word_logprobs.Scale(scale);
}
}
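
A worked illustration of the clamp just added (all numbers are made up, not taken from any actual run):

// Hypothetical values, only to show how the scale factor behaves.
BaseFloat limit = -10.0;            // objective_config.den_term_limit (default)
BaseFloat objf_den = -2500.0;       // summed denominator part of the objective
BaseFloat weight = 100.0;           // total output weight for the minibatch
BaseFloat avg = objf_den / weight;  // -25.0: more negative than the limit
if (avg < limit) {
  BaseFloat scale = limit / avg;    // (-10.0) / (-25.0) = 0.4, i.e. in (0, 1)
  // word_logprobs.Scale(scale);    // denominator derivatives are scaled down
}

For a typical later-stage minibatch with avg around -0.01, the condition avg < limit is false and nothing is scaled.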


// This adds -1.0 to the elements of 'word_logprobs' corresponding
// to the output words. This array 'word_logprobs' is going to
@@ -225,6 +244,7 @@ static void ProcessRnnlmOutputSampling(
// This is called from ProcessRnnlmOutput() when we are not doing importance
// sampling.
static void ProcessRnnlmOutputNoSampling(
const RnnlmObjectiveOptions &objective_config,
const RnnlmExample &minibatch,
const RnnlmExampleDerived &derived,
const CuMatrixBase<BaseFloat> &word_embedding,
Expand All @@ -235,6 +255,7 @@ static void ProcessRnnlmOutputNoSampling(
BaseFloat *objf_num,
BaseFloat *objf_den,
BaseFloat *objf_den_exact) {
KALDI_ASSERT(weight != NULL && objf_den != NULL); // Others are optional.

int32 embedding_dim = word_embedding.NumCols();
int32 num_words = word_embedding.NumRows();
@@ -247,9 +268,8 @@
word_logprobs.AddMatMat(1.0, nnet_output, kNoTrans,
word_embedding, kTrans, 0.0);

if (weight) {
*weight = minibatch.output_weights.Sum();
}
*weight = minibatch.output_weights.Sum();

if (objf_num) {
*objf_num = TraceMatSmat(word_logprobs,
derived.output_words_smat, kTrans);
@@ -282,7 +302,8 @@ static void ProcessRnnlmOutputNoSampling(
// and some of these code paths will only be used in test code.
word_logprobs.ApplyExpSpecial();

if (objf_den) {
{ // This block computes *objf_den.

// we call this variable 'q_noeps' because in the math described in
// rnnlm-example-utils.h it is described as q(i,w), and because we're
// skipping over the epsilon symbol (which we don't want to include in the
@@ -318,6 +339,29 @@

// Include the factor 'minibatch.output_weights'.
word_logprobs.MulRowsVec(minibatch.output_weights);



if (objective_config.den_term_limit != 0.0) {
// If it's nonzero then check that it's negative, and not too close to zero,
// which would likely cause failure to converge. The default is -10.0.
KALDI_ASSERT(objective_config.den_term_limit < -0.5);
BaseFloat limit = objective_config.den_term_limit;
if (weight != NULL && objf_den != NULL && *weight > 0 &&
(*objf_den / *weight) < limit) {
// note: both things being divided below are negative, and
// 'scale' will be between zero and one.
BaseFloat scale = limit / (*objf_den / *weight);
// We scale the denominator part of the objective down by the inverse of
// the factor by which the denominator part of the objective exceeds the
// limit. This point in the code should only be reached on the first few
// iterations of training, or if there is some kind of instability,
// because the (*objf_den / *weight) will usually be close to zero,
// e.g. -0.01, while 'limit' is expected to be much larger in magnitude, like -10.0.
word_logprobs.Scale(scale);
}
}

// After the following statement, 'word_logprobs' will contain the negative
// of the derivative of the objective function w.r.t. l(i, x), except that the
// first column (for epsilon) should be ignored.
@@ -353,6 +397,7 @@


void ProcessRnnlmOutput(
const RnnlmObjectiveOptions &objective_config,
const RnnlmExample &minibatch,
const RnnlmExampleDerived &derived,
const CuMatrixBase<BaseFloat> &word_embedding,
@@ -371,12 +416,14 @@ void ProcessRnnlmOutput(

bool using_sampling = !(minibatch.sampled_words.empty());
if (using_sampling) {
ProcessRnnlmOutputSampling(minibatch, derived, word_embedding,
ProcessRnnlmOutputSampling(objective_config,
minibatch, derived, word_embedding,
nnet_output, word_embedding_deriv,
nnet_output_deriv, weight, objf_num,
objf_den, objf_den_exact);
} else {
ProcessRnnlmOutputNoSampling(minibatch, derived, word_embedding,
ProcessRnnlmOutputNoSampling(objective_config,
minibatch, derived, word_embedding,
nnet_output, word_embedding_deriv,
nnet_output_deriv, weight, objf_num,
objf_den, objf_den_exact);
34 changes: 28 additions & 6 deletions src/rnnlm/rnnlm-example-utils.h
@@ -148,6 +148,26 @@ void GetRnnlmExampleDerived(const RnnlmExample &minibatch,
bool need_embedding_deriv,
RnnlmExampleDerived *derived);

/**
Configuration class relating to the objective function used for RNNLM
training, more specifically for use by the function ProcessRnnlmOutput().
*/
struct RnnlmObjectiveOptions {
BaseFloat den_term_limit;

RnnlmObjectiveOptions(): den_term_limit(-10.0) { }

void Register(OptionsItf *po) {
po->Register("den-term-limit", &den_term_limit,
"Modification to the with-sampling objective, that prevents "
"instability early in training, but in the end makes no difference. "
"We scale down the denominator part of the objective when the "
"average denominator part of the objective, for this minibatch, "
"is more negative than this value. Set this to 0.0 to use "
"unmodified objective function.");
}
};
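
For context, a minimal sketch of how this options struct would be wired into a command-line binary. The binary itself is hypothetical; only Register() and the -10.0 default come from the struct above:

#include "util/parse-options.h"
#include "rnnlm/rnnlm-example-utils.h"

int main(int argc, char *argv[]) {
  using namespace kaldi;
  using namespace kaldi::rnnlm;
  const char *usage = "Hypothetical RNNLM training binary.\n";
  ParseOptions po(usage);
  RnnlmObjectiveOptions objective_opts;  // den_term_limit defaults to -10.0
  objective_opts.Register(&po);          // exposes --den-term-limit
  po.Read(argc, argv);
  // objective_opts is then passed as the first argument of ProcessRnnlmOutput(),
  // or into the RnnlmTrainer / RnnlmCoreTrainer constructors, as the other
  // changes in this commit show.
  return 0;
}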

/**
This function processes the output of the RNNLM computation for a single
minibatch; it outputs the objective-function contributions from the
@@ -178,7 +198,6 @@ void GetRnnlmExampleDerived(const RnnlmExample &minibatch,
closer bound to the natural normalizer term and helps avoid
instability in early phases of training.]
With importance sampling (if minibatch.sampled_words.size() > 0):
'den_term' equals
den_term(i) = 1.0 - (\sum_w q(w,i) * sample_inv_prob(w,i))
@@ -210,30 +229,33 @@ void GetRnnlmExampleDerived(const RnnlmExample &minibatch,
@param [out] nnet_output_deriv If non-NULL, the derivative of the
objective function w.r.t. 'nnet_output' is *added*
to this location.
@param [out] weight If non-NULL, the total weight over this
@param [out] weight Must be non-NULL. The total weight over this
minibatch will be *written to* here (will equal
minibatch.output_weights.Sum()).
@param [out] objf_num If non-NULL, the total numerator part of
the objective function will be written here, i.e.
the sum over i of weight(i) * num_term(i); see above
for details.
@param [out] objf_den If non-NULL, the total denominator part of
@param [out] objf_den Must be non-NULL. The total denominator part of
the objective function will be written here, i.e.
the sum over i of weight(i) * den_term(i); see above
for details. You add this to 'objf_num' to get the
total objective function.
@param [out] objf_den_exact If non-NULL, then if we're not
doing sampling (minibatch.sampled_words.empty()),
@param [out] objf_den_exact If non-NULL, and if we're not
doing sampling (i.e. if minibatch.sampled_words.empty()),
the 'exact' denominator part of the objective function
will be written here, i.e. the weighted sum of
exact_den_term(i) = -log(\sum_w p(i,w)).
If we are sampling, then there is no exact denominator
part, and this will be set to zero. This is provided
for diagnostic purposes. Derivatives will be computed
w.r.t. the objective consisting of
'objf_num + objf_den'.
'objf_num + objf_den', i.e. ignoring the 'exact' one.
For greatest efficiency you should probably not provide
this pointer.
*/
void ProcessRnnlmOutput(
const RnnlmObjectiveOptions &objective_opts,
const RnnlmExample &minibatch,
const RnnlmExampleDerived &derived,
const CuMatrixBase<BaseFloat> &word_embedding,
16 changes: 12 additions & 4 deletions src/rnnlm/rnnlm-training.cc
@@ -28,12 +28,14 @@ namespace rnnlm {
RnnlmTrainer::RnnlmTrainer(bool train_embedding,
const RnnlmCoreTrainerOptions &core_config,
const RnnlmEmbeddingTrainerOptions &embedding_config,
const RnnlmObjectiveOptions &objective_config,
const CuSparseMatrix<BaseFloat> *word_feature_mat,
CuMatrix<BaseFloat> *embedding_mat,
nnet3::Nnet *rnnlm):
train_embedding_(train_embedding),
core_config_(core_config),
embedding_config_(embedding_config),
objective_config_(objective_config),
rnnlm_(rnnlm),
core_trainer_(NULL),
embedding_mat_(embedding_mat),
@@ -54,7 +56,7 @@ RnnlmTrainer::RnnlmTrainer(bool train_embedding,
<< "equal to embedding dimension " << embedding_dim
<< " but got " << rnnlm_input_dim << " and "
<< rnnlm_output_dim;
core_trainer_ = new RnnlmCoreTrainer(core_config, rnnlm_);
core_trainer_ = new RnnlmCoreTrainer(core_config_, objective_config_, rnnlm_);

if (train_embedding) {
embedding_trainer_ = new RnnlmEmbeddingTrainer(embedding_config,
@@ -167,9 +169,15 @@ void RnnlmTrainer::TrainWordEmbedding(
embedding_mat_->NumCols());
const CuSparseMatrix<BaseFloat> &word_features_trans =
(sampling ? active_word_features_trans_ : word_feature_mat_transpose_);
feature_embedding_deriv.AddMatSmat(1.0, *word_embedding_deriv,
word_features_trans, kTrans,
0.0);

feature_embedding_deriv.AddSmatMat(1.0, word_features_trans, kNoTrans,
*word_embedding_deriv, 0.0);

// TODO: eventually remove these lines.
KALDI_VLOG(3) << "word-features-trans sum is " << word_features_trans.Sum()
<< ", word-embedding-deriv-sum is " << word_embedding_deriv->Sum()
<< ", feature-embedding-deriv-sum is " << feature_embedding_deriv.Sum();

embedding_trainer_->Train(&feature_embedding_deriv);
}
}
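
A sketch of the shape bookkeeping behind the switch to AddSmatMat above (dimension names are illustrative and assume the usual factorization where the word embedding is the sparse word-feature matrix times a dense feature embedding):

//   word_embedding          =  word_features        *  feature_embedding
//   (num_words x dim)          (num_words x nfeat)     (nfeat x dim)
// so, by the chain rule,
//   feature_embedding_deriv =  word_features^T       *  word_embedding_deriv
//   (nfeat x dim)              (nfeat x num_words)      (num_words x dim)
// which is what AddSmatMat(1.0, word_features_trans, kNoTrans,
//                          *word_embedding_deriv, 0.0) computes, since
// word_features_trans already holds the transposed (nfeat x num_words) matrix.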