Commit

[src] Make RNNLM training single threaded (workaround for CuSparse library bugs) (kaldi-asr#2520)
hainan-xv authored and danpovey committed Jun 27, 2018
1 parent 60141df commit 8ce3a95
Showing 2 changed files with 35 additions and 152 deletions.
109 changes: 30 additions & 79 deletions src/rnnlm/rnnlm-training.cc
@@ -42,9 +42,6 @@ RnnlmTrainer::RnnlmTrainer(bool train_embedding,
embedding_trainer_(NULL),
word_feature_mat_(word_feature_mat),
num_minibatches_processed_(0),
end_of_input_(false),
previous_minibatch_empty_(1),
current_minibatch_empty_(1),
srand_seed_(RandInt(0, 100000)) {


@@ -75,13 +72,6 @@ RnnlmTrainer::RnnlmTrainer(bool train_embedding,
<< embedding_mat_->NumRows() << " (mismatch).";
}
}

// Start a thread that calls run_background_thread(this).
// That thread will be responsible for computing derived variables of
// the minibatch, since that can be done independently of the main
// training process.
background_thread_ = std::thread(run_background_thread, this);

}


@@ -92,25 +82,40 @@ void RnnlmTrainer::Train(RnnlmExample *minibatch) {
<< VocabSize() << ", got "
<< minibatch->vocab_size;

// hand over 'minibatch' to the background thread to have its derived variable
// computed, via the class variable 'current_minibatch_'.
current_minibatch_empty_.Wait();
current_minibatch_.Swap(minibatch);
current_minibatch_full_.Signal();
num_minibatches_processed_++;
if (num_minibatches_processed_ == 1) {
return; // The first time this function is called, return immediately
// because there is no previous minibatch to train on.
RnnlmExampleDerived derived;
CuArray<int32> active_words_cuda;
CuSparseMatrix<BaseFloat> active_word_features;
CuSparseMatrix<BaseFloat> active_word_features_trans;

if (!current_minibatch_.sampled_words.empty()) {
std::vector<int32> active_words;
RenumberRnnlmExample(&current_minibatch_, &active_words);
active_words_cuda.CopyFromVec(active_words);

if (word_feature_mat_ != NULL) {
active_word_features.SelectRows(active_words_cuda,
*word_feature_mat_);
active_word_features_trans.CopyFromSmat(active_word_features,
kTrans);
}
}
previous_minibatch_full_.Wait();
GetRnnlmExampleDerived(current_minibatch_, train_embedding_,
&derived);

derived_.Swap(&derived);
active_words_.Swap(&active_words_cuda);
active_word_features_.Swap(&active_word_features);
active_word_features_trans_.Swap(&active_word_features_trans);

TrainInternal();
previous_minibatch_empty_.Signal();
}


void RnnlmTrainer::GetWordEmbedding(CuMatrix<BaseFloat> *word_embedding_storage,
CuMatrix<BaseFloat> **word_embedding) {
RnnlmExample &minibatch = previous_minibatch_;
RnnlmExample &minibatch = current_minibatch_;
bool sampling = !minibatch.sampled_words.empty();

if (word_feature_mat_ == NULL) {
@@ -148,7 +153,7 @@ void RnnlmTrainer::GetWordEmbedding(CuMatrix<BaseFloat> *word_embedding_storage,

void RnnlmTrainer::TrainWordEmbedding(
CuMatrixBase<BaseFloat> *word_embedding_deriv) {
RnnlmExample &minibatch = previous_minibatch_;
RnnlmExample &minibatch = current_minibatch_;
bool sampling = !minibatch.sampled_words.empty();

if (word_feature_mat_ == NULL) {
@@ -186,7 +191,7 @@ void RnnlmTrainer::TrainWordEmbedding(
void RnnlmTrainer::TrainBackstitchWordEmbedding(
bool is_backstitch_step1,
CuMatrixBase<BaseFloat> *word_embedding_deriv) {
RnnlmExample &minibatch = previous_minibatch_;
RnnlmExample &minibatch = current_minibatch_;
bool sampling = !minibatch.sampled_words.empty();

if (word_feature_mat_ == NULL) {
@@ -239,21 +244,21 @@ void RnnlmTrainer::TrainInternal() {
srand_seed_ % core_config_.backstitch_training_interval) {
bool is_backstitch_step1 = true;
srand(srand_seed_ + num_minibatches_processed_);
core_trainer_->TrainBackstitch(is_backstitch_step1, previous_minibatch_,
core_trainer_->TrainBackstitch(is_backstitch_step1, current_minibatch_,
derived_, *word_embedding,
(train_embedding_ ? &word_embedding_deriv : NULL));
if (train_embedding_)
TrainBackstitchWordEmbedding(is_backstitch_step1, &word_embedding_deriv);

is_backstitch_step1 = false;
srand(srand_seed_ + num_minibatches_processed_);
core_trainer_->TrainBackstitch(is_backstitch_step1, previous_minibatch_,
core_trainer_->TrainBackstitch(is_backstitch_step1, current_minibatch_,
derived_, *word_embedding,
(train_embedding_ ? &word_embedding_deriv : NULL));
if (train_embedding_)
TrainBackstitchWordEmbedding(is_backstitch_step1, &word_embedding_deriv);
} else {
core_trainer_->Train(previous_minibatch_, derived_, *word_embedding,
core_trainer_->Train(current_minibatch_, derived_, *word_embedding,
(train_embedding_ ? &word_embedding_deriv : NULL));
if (train_embedding_)
TrainWordEmbedding(&word_embedding_deriv);
@@ -265,61 +270,7 @@ int32 RnnlmTrainer::VocabSize() {
else return embedding_mat_->NumRows();
}

void RnnlmTrainer::RunBackgroundThread() {
while (true) {
current_minibatch_full_.Wait();
if (end_of_input_)
return;
RnnlmExampleDerived derived;
CuArray<int32> active_words_cuda;
CuSparseMatrix<BaseFloat> active_word_features;
CuSparseMatrix<BaseFloat> active_word_features_trans;

if (!current_minibatch_.sampled_words.empty()) {
std::vector<int32> active_words;
RenumberRnnlmExample(&current_minibatch_, &active_words);
active_words_cuda.CopyFromVec(active_words);

if (word_feature_mat_ != NULL) {
active_word_features.SelectRows(active_words_cuda,
*word_feature_mat_);
active_word_features_trans.CopyFromSmat(active_word_features,
kTrans);
}
}
GetRnnlmExampleDerived(current_minibatch_, train_embedding_,
&derived);

// Wait until the main thread is not currently processing
// previous_minibatch_; once we get this semaphore we are free to write to
// it and other related variables such as 'derived_'.
previous_minibatch_empty_.Wait();
previous_minibatch_.Swap(&current_minibatch_);
derived_.Swap(&derived);
active_words_.Swap(&active_words_cuda);
active_word_features_.Swap(&active_word_features);
active_word_features_trans_.Swap(&active_word_features_trans);

// The following statement signals that 'previous_minibatch_'
// and related variables have been written to by this thread.
previous_minibatch_full_.Signal();
// The following statement signals that 'current_minibatch_'
// has been consumed by this thread and is no longer needed.
current_minibatch_empty_.Signal();
}
}

RnnlmTrainer::~RnnlmTrainer() {
// Train on the last minibatch, because Train() always trains on the previously
// provided one (for threading reasons).
if (num_minibatches_processed_ > 0) {
previous_minibatch_full_.Wait();
TrainInternal();
}
end_of_input_ = true;
current_minibatch_full_.Signal();
background_thread_.join();

// Note: the following delete statements may cause some diagnostics to be
// issued, from the destructors of those classes.
if (core_trainer_)
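For readers unfamiliar with the pattern this commit removes: the Train()/RunBackgroundThread() pair deleted above implemented a double-buffered producer/consumer handshake. Two pairs of semaphores ('current'/'previous' x 'full'/'empty') let a background thread compute each minibatch's derived variables while the foreground thread trained on the previous minibatch, and the destructor trained on the final pending minibatch. Below is a self-contained toy sketch of that handshake only; the types, the trivial "derived" computation, and the small Semaphore class are stand-ins, not Kaldi's RnnlmExample/RnnlmExampleDerived/Semaphore.

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

// Minimal counting semaphore, standing in for Kaldi's Semaphore class
// (C++20's std::counting_semaphore would do equally well).
class Semaphore {
 public:
  explicit Semaphore(int count = 0) : count_(count) {}
  void Signal() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++count_;
    cond_.notify_one();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return count_ > 0; });
    --count_;
  }
 private:
  int count_;
  std::mutex mutex_;
  std::condition_variable cond_;
};

struct Example { std::vector<int> words; };  // stand-in for RnnlmExample
struct Derived { int num_active = 0; };      // stand-in for RnnlmExampleDerived

class ToyTrainer {
 public:
  ToyTrainer() : background_thread_(&ToyTrainer::Run, this) {}

  // Foreground: hand the new eg to the background thread, then train on the
  // *previous* eg (whose derived variables are ready by now).
  void Train(Example *eg) {
    current_minibatch_empty_.Wait();
    std::swap(current_minibatch_, *eg);
    current_minibatch_full_.Signal();
    if (++num_minibatches_processed_ == 1)
      return;  // nothing previous to train on yet.
    previous_minibatch_full_.Wait();
    TrainInternal();
    previous_minibatch_empty_.Signal();
  }

  ~ToyTrainer() {
    if (num_minibatches_processed_ > 0) {  // train on the last pending eg.
      previous_minibatch_full_.Wait();
      TrainInternal();
    }
    end_of_input_ = true;
    current_minibatch_full_.Signal();
    background_thread_.join();
  }

 private:
  // Background: compute the 'derived' variables for each new eg, then publish
  // them (together with the eg itself) for the foreground thread to train on.
  void Run() {
    while (true) {
      current_minibatch_full_.Wait();
      if (end_of_input_) return;
      Derived d;
      d.num_active = static_cast<int>(current_minibatch_.words.size());
      previous_minibatch_empty_.Wait();
      std::swap(previous_minibatch_, current_minibatch_);
      std::swap(derived_, d);
      previous_minibatch_full_.Signal();
      current_minibatch_empty_.Signal();
    }
  }

  void TrainInternal() {
    std::cout << "training on a minibatch of " << derived_.num_active
              << " words\n";
  }

  Example current_minibatch_, previous_minibatch_;
  Derived derived_;
  bool end_of_input_ = false;
  int num_minibatches_processed_ = 0;
  // Initial values mirror the removed initializers: the 'empty' pair starts at 1.
  Semaphore current_minibatch_full_{0}, current_minibatch_empty_{1};
  Semaphore previous_minibatch_full_{0}, previous_minibatch_empty_{1};
  std::thread background_thread_;  // declared last, so it starts after the rest.
};

int main() {
  ToyTrainer trainer;
  for (int i = 1; i <= 3; i++) {
    Example eg;
    eg.words.assign(10 * i, 0);
    trainer.Train(&eg);  // acquires 'eg' destructively, via swap.
  }
  return 0;  // ~ToyTrainer trains on the final minibatch and joins the thread.
}
```

The commit replaces this handshake with a purely sequential Train(): the derived-variable computation now happens inline, which is what the changes to rnnlm-training.cc above show.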
78 changes: 5 additions & 73 deletions src/rnnlm/rnnlm-training.h
@@ -20,7 +20,6 @@
#ifndef KALDI_RNNLM_RNNLM_TRAINING_H_
#define KALDI_RNNLM_RNNLM_TRAINING_H_

#include <thread>
#include "rnnlm/rnnlm-core-training.h"
#include "rnnlm/rnnlm-embedding-training.h"
#include "rnnlm/rnnlm-utils.h"
@@ -79,10 +78,7 @@ class RnnlmTrainer {


// Train on one example. The example is provided as a pointer because we
// acquire it destructively, via Swap(). Note: this function doesn't
// actually train on this eg; what it does is to train on the previous
// example, and provide this eg to the background thread that computes the
// derived parameters of the eg.
// acquire it destructively, via Swap().
void Train(RnnlmExample *minibatch);


@@ -129,16 +125,6 @@
bool is_backstitch_step1,
CuMatrixBase<BaseFloat> *word_embedding_deriv);

/// This is the function-call that's run as the background thread which
/// computes the derived parameters for each minibatch.
void RunBackgroundThread();

/// This function is invoked by the newly created background thread.
static void run_background_thread(RnnlmTrainer *trainer) {
trainer->RunBackgroundThread();
}


bool train_embedding_; // true if we are training the embedding.
const RnnlmCoreTrainerOptions &core_config_;
const RnnlmEmbeddingTrainerOptions &embedding_config_;
@@ -173,32 +159,14 @@
// it's needed.
CuSparseMatrix<BaseFloat> word_feature_mat_transpose_;


// num_minibatches_processed_ starts at zero and is incremented each time we
// provide an example to the background thread for computing the derived
// parameters.
int32 num_minibatches_processed_;

// 'current_minibatch' is where the Train() function puts the minibatch that
// is provided to Train(), so that the background thread can work on it.
RnnlmExample current_minibatch_;
// View 'end_of_input_' as part of a unit with current_minibatch_, for threading/access
// purposes. It is set by the foreground thread from the destructor, while
// incrementing the current_minibatch_full_ semaphore; and when the background
// thread decrements the semaphore and notices that end_of_input_ is true, it will
// exit.
bool end_of_input_;


// previous_minibatch_ is the previous minibatch that was provided to Train(),
// but is the minibatch that we're currently training on.
RnnlmExample previous_minibatch_;
// The variables derived_ and active_words_ [and more that I'll add, TODO] are in the same
// group as previous_minibatch_ from the point of view
// of threading and access control.
RnnlmExampleDerived derived_;

// The variables derived_ and active_words_ are in the same group as current_minibatch_.
RnnlmExampleDerived derived_;
// Only if we are doing subsampling (depends on the eg), active_words_
// contains the list of active words for the minibatch 'previous_minibatch_';
// contains the list of active words for the minibatch 'current_minibatch_';
// it is a CUDA version of the 'active_words' output by
// RenumberRnnlmExample(). Otherwise it is empty.
CuArray<int32> active_words_;
@@ -212,42 +180,6 @@
// This is a derived quantity computed by the background thread.
CuSparseMatrix<BaseFloat> active_word_features_trans_;


// The 'previous_minibatch_full_' semaphore is incremented by the background
// thread once it has written to 'previous_minibatch_' and
// 'derived_previous_', to let the Train() function know that they are ready
// to be trained on. The Train() function waits on this semaphore.
Semaphore previous_minibatch_full_;

// The 'previous_minibatch_empty_' semaphore is incremented by the foreground
// thread when it is done processing previous_minibatch_ and
// derived_ and active_words_ (and hence, it is safe for the background thread to write
// to these variables). The background thread waits on this semaphore once it
// has finished computing the derived variables; and when it successfully
// decrements it, it will write to those variables (quickly, via Swap()).
Semaphore previous_minibatch_empty_;


// The 'current_minibatch_full_' semaphore is incremented by the foreground
// thread from Train(), when it has written the just-provided minibatch to
// 'current_minibatch_' (it's also incremented by the destructor, together
// with setting end_of_input_). The background thread waits on this semaphore
// before either processing previous_minibatch (if !end_of_input_), or exiting
// (if end_of_input_).
Semaphore current_minibatch_full_;

// The 'current_minibatch_empty_' semaphore is incremented by the background
// thread when it is done processing current_minibatch_
// (so it is safe for the foreground thread to write
// to this variable). The foreground thread waits on this semaphore before
// writing to 'current_minibatch_' (in practice it should get the semaphore
// immediately since we expect that the foreground thread will have more to
// do than the background thread).
Semaphore current_minibatch_empty_;

std::thread background_thread_; // Background thread for computing 'derived'
// parameters of a minibatch.

// This value is used in backstitch training when we need to ensure
// consistent dropout masks. It's set to a value derived from rand()
// when the class is initialized.

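The header comment above notes that Train() "acquires it destructively, via Swap()". For illustration only, here is a tiny self-contained sketch of that ownership convention, with toy stand-in types rather than Kaldi's classes: the caller keeps one buffer object and refills it after each call, and no deep copy of the example is made.

```cpp
#include <iostream>
#include <vector>

// Toy stand-in for RnnlmExample: just a bag of word ids with a Swap() method.
struct Example {
  std::vector<int> words;
  void Swap(Example *other) { words.swap(other->words); }
};

// Toy trainer whose Train(), like RnnlmTrainer::Train(), takes the example by
// pointer and acquires it destructively via Swap() -- no deep copy is made.
class ToySwapTrainer {
 public:
  void Train(Example *eg) {
    current_minibatch_.Swap(eg);  // *eg now holds our old (stale) buffer.
    std::cout << "training on " << current_minibatch_.words.size()
              << " words\n";
  }
 private:
  Example current_minibatch_;
};

int main() {
  ToySwapTrainer trainer;
  Example eg;
  for (int i = 1; i <= 3; i++) {
    eg.words.assign(100 * i, 0);  // refill the same object each iteration;
    trainer.Train(&eg);           // its contents are swapped out, not copied.
  }
  return 0;
}
```

Note that with this commit, Train() also does all of the training itself, so the destructor no longer has a pending minibatch to handle.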