diff --git a/src/rnnlm/rnnlm-training.cc b/src/rnnlm/rnnlm-training.cc index 959906be2f2..370f6395dc0 100644 --- a/src/rnnlm/rnnlm-training.cc +++ b/src/rnnlm/rnnlm-training.cc @@ -42,9 +42,6 @@ RnnlmTrainer::RnnlmTrainer(bool train_embedding, embedding_trainer_(NULL), word_feature_mat_(word_feature_mat), num_minibatches_processed_(0), - end_of_input_(false), - previous_minibatch_empty_(1), - current_minibatch_empty_(1), srand_seed_(RandInt(0, 100000)) { @@ -75,13 +72,6 @@ RnnlmTrainer::RnnlmTrainer(bool train_embedding, << embedding_mat_->NumRows() << " (mismatch)."; } } - - // Start a thread that calls run_background_thread(this). - // That thread will be responsible for computing derived variables of - // the minibatch, since that can be done independently of the main - // training process. - background_thread_ = std::thread(run_background_thread, this); - } @@ -92,25 +82,40 @@ void RnnlmTrainer::Train(RnnlmExample *minibatch) { << VocabSize() << ", got " << minibatch->vocab_size; - // hand over 'minibatch' to the background thread to have its derived variable - // computed, via the class variable 'current_minibatch_'. - current_minibatch_empty_.Wait(); current_minibatch_.Swap(minibatch); - current_minibatch_full_.Signal(); num_minibatches_processed_++; - if (num_minibatches_processed_ == 1) { - return; // The first time this function is called, return immediately - // because there is no previous minibatch to train on. 
+ RnnlmExampleDerived derived; + CuArray<int32> active_words_cuda; + CuSparseMatrix<BaseFloat> active_word_features; + CuSparseMatrix<BaseFloat> active_word_features_trans; + + if (!current_minibatch_.sampled_words.empty()) { + std::vector<int32> active_words; + RenumberRnnlmExample(&current_minibatch_, &active_words); + active_words_cuda.CopyFromVec(active_words); + + if (word_feature_mat_ != NULL) { + active_word_features.SelectRows(active_words_cuda, + *word_feature_mat_); + active_word_features_trans.CopyFromSmat(active_word_features, + kTrans); + } } - previous_minibatch_full_.Wait(); + GetRnnlmExampleDerived(current_minibatch_, train_embedding_, + &derived); + + derived_.Swap(&derived); + active_words_.Swap(&active_words_cuda); + active_word_features_.Swap(&active_word_features); + active_word_features_trans_.Swap(&active_word_features_trans); + TrainInternal(); - previous_minibatch_empty_.Signal(); } void RnnlmTrainer::GetWordEmbedding(CuMatrix<BaseFloat> *word_embedding_storage, CuMatrix<BaseFloat> **word_embedding) { - RnnlmExample &minibatch = previous_minibatch_; + RnnlmExample &minibatch = current_minibatch_; bool sampling = !minibatch.sampled_words.empty(); if (word_feature_mat_ == NULL) { @@ -148,7 +153,7 @@ void RnnlmTrainer::GetWordEmbedding(CuMatrix<BaseFloat> *word_embedding_storage, void RnnlmTrainer::TrainWordEmbedding( CuMatrixBase<BaseFloat> *word_embedding_deriv) { - RnnlmExample &minibatch = previous_minibatch_; + RnnlmExample &minibatch = current_minibatch_; bool sampling = !minibatch.sampled_words.empty(); if (word_feature_mat_ == NULL) { @@ -186,7 +191,7 @@ void RnnlmTrainer::TrainWordEmbedding( void RnnlmTrainer::TrainBackstitchWordEmbedding( bool is_backstitch_step1, CuMatrixBase<BaseFloat> *word_embedding_deriv) { - RnnlmExample &minibatch = previous_minibatch_; + RnnlmExample &minibatch = current_minibatch_; bool sampling = !minibatch.sampled_words.empty(); if (word_feature_mat_ == NULL) { @@ -239,7 +244,7 @@ void RnnlmTrainer::TrainInternal() { srand_seed_ % core_config_.backstitch_training_interval) { bool is_backstitch_step1 = 
true; srand(srand_seed_ + num_minibatches_processed_); - core_trainer_->TrainBackstitch(is_backstitch_step1, previous_minibatch_, + core_trainer_->TrainBackstitch(is_backstitch_step1, current_minibatch_, derived_, *word_embedding, (train_embedding_ ? &word_embedding_deriv : NULL)); if (train_embedding_) @@ -247,13 +252,13 @@ void RnnlmTrainer::TrainInternal() { is_backstitch_step1 = false; srand(srand_seed_ + num_minibatches_processed_); - core_trainer_->TrainBackstitch(is_backstitch_step1, previous_minibatch_, + core_trainer_->TrainBackstitch(is_backstitch_step1, current_minibatch_, derived_, *word_embedding, (train_embedding_ ? &word_embedding_deriv : NULL)); if (train_embedding_) TrainBackstitchWordEmbedding(is_backstitch_step1, &word_embedding_deriv); } else { - core_trainer_->Train(previous_minibatch_, derived_, *word_embedding, + core_trainer_->Train(current_minibatch_, derived_, *word_embedding, (train_embedding_ ? &word_embedding_deriv : NULL)); if (train_embedding_) TrainWordEmbedding(&word_embedding_deriv); @@ -265,61 +270,7 @@ int32 RnnlmTrainer::VocabSize() { else return embedding_mat_->NumRows(); } -void RnnlmTrainer::RunBackgroundThread() { - while (true) { - current_minibatch_full_.Wait(); - if (end_of_input_) - return; - RnnlmExampleDerived derived; - CuArray<int32> active_words_cuda; - CuSparseMatrix<BaseFloat> active_word_features; - CuSparseMatrix<BaseFloat> active_word_features_trans; - - if (!current_minibatch_.sampled_words.empty()) { - std::vector<int32> active_words; - RenumberRnnlmExample(&current_minibatch_, &active_words); - active_words_cuda.CopyFromVec(active_words); - - if (word_feature_mat_ != NULL) { - active_word_features.SelectRows(active_words_cuda, - *word_feature_mat_); - active_word_features_trans.CopyFromSmat(active_word_features, - kTrans); - } - } - GetRnnlmExampleDerived(current_minibatch_, train_embedding_, - &derived); - - // Wait until the main thread is not currently processing - // previous_minibatch_; once we get this semaphore we are free to write to - // it 
and other related variables such as 'derived_'. - previous_minibatch_empty_.Wait(); - previous_minibatch_.Swap(&current_minibatch_); - derived_.Swap(&derived); - active_words_.Swap(&active_words_cuda); - active_word_features_.Swap(&active_word_features); - active_word_features_trans_.Swap(&active_word_features_trans); - - // The following statement signals that 'previous_minibatch_' - // and related variables have been written to by this thread. - previous_minibatch_full_.Signal(); - // The following statement signals that 'current_minibatch_' - // has been consumed by this thread and is no longer needed. - current_minibatch_empty_.Signal(); - } -} - RnnlmTrainer::~RnnlmTrainer() { - // Train on the last minibatch, because Train() always trains on the previously - // provided one (for threading reasons). - if (num_minibatches_processed_ > 0) { - previous_minibatch_full_.Wait(); - TrainInternal(); - } - end_of_input_ = true; - current_minibatch_full_.Signal(); - background_thread_.join(); - // Note: the following delete statements may cause some diagnostics to be // issued, from the destructors of those classes. if (core_trainer_) diff --git a/src/rnnlm/rnnlm-training.h b/src/rnnlm/rnnlm-training.h index e1eec79a3ff..d0a9a1a32e4 100644 --- a/src/rnnlm/rnnlm-training.h +++ b/src/rnnlm/rnnlm-training.h @@ -20,7 +20,6 @@ #ifndef KALDI_RNNLM_RNNLM_TRAINING_H_ #define KALDI_RNNLM_RNNLM_TRAINING_H_ -#include <thread> #include "rnnlm/rnnlm-core-training.h" #include "rnnlm/rnnlm-embedding-training.h" #include "rnnlm/rnnlm-utils.h" @@ -79,10 +78,7 @@ class RnnlmTrainer { // Train on one example. The example is provided as a pointer because we - // acquire it destructively, via Swap(). Note: this function doesn't - // actually train on this eg; what it does is to train on the previous - // example, and provide this eg to the background thread that computes the - // derived parameters of the eg. + // acquire it destructively, via Swap(). 
void Train(RnnlmExample *minibatch); @@ -129,16 +125,6 @@ bool is_backstitch_step1, CuMatrixBase<BaseFloat> *word_embedding_deriv); - /// This is the function-call that's run as the background thread which - /// computes the derived parameters for each minibatch. - void RunBackgroundThread(); - - /// This function is invoked by the newly created background thread. - static void run_background_thread(RnnlmTrainer *trainer) { - trainer->RunBackgroundThread(); - } - - bool train_embedding_; // true if we are training the embedding. const RnnlmCoreTrainerOptions &core_config_; const RnnlmEmbeddingTrainerOptions &embedding_config_; @@ -173,32 +159,14 @@ class RnnlmTrainer { // it's needed. CuSparseMatrix<BaseFloat> word_feature_mat_transpose_; - - // num_minibatches_processed_ starts at zero is incremented each time we - // provide an example to the background thread for computing the derived - // parameters. int32 num_minibatches_processed_; - // 'current_minibatch' is where the Train() function puts the minibatch that - // is provided to Train(), so that the background thread can work on it. RnnlmExample current_minibatch_; - // View 'end_of_input_' as part of a unit with current_minibatch_, for threading/access - // purposes. It is set by the foreground thread from the destructor, while - // incrementing the current_minibatch_ready_ semaphore; and when the background - // thread decrements the semaphore and notices that end_of_input_ is true, it will - // exit. - bool end_of_input_; - - - // previous_minibatch_ is the previous minibatch that was provided to Train(), - // but the minibatch that we're currently trainig on. - RnnlmExample previous_minibatch_; - // The variables derived_ and active_words_ [and more that I'll add, TODO] are in the same - // group as previous_minibatch_ from the point of view - // of threading and access control. - RnnlmExampleDerived derived_; + + // The variables derived_ and active_words_ are in the same group as current_minibatch_. 
+ RnnlmExampleDerived derived_; // Only if we are doing subsampling (depends on the eg), active_words_ - // contains the list of active words for the minibatch 'previous_minibatch_'; + // contains the list of active words for the minibatch 'current_minibatch_'; // it is a CUDA version of the 'active_words' output by // RenumberRnnlmExample(). Otherwise it is empty. CuArray<int32> active_words_; @@ -212,42 +180,6 @@ // This is a derived quantity computed by the background thread. CuSparseMatrix<BaseFloat> active_word_features_trans_; - - // The 'previous_minibatch_full_' semaphore is incremented by the background - // thread once it has written to 'previous_minibatch_' and - // 'derived_previous_', to let the Train() function know that they are ready - // to be trained on. The Train() function waits on this semaphore. - Semaphore previous_minibatch_full_; - - // The 'previous_minibatch_empty_' semaphore is incremented by the foreground - // thread when it has done processing previous_minibatch_ and - // derived_ and active_words_ (and hence, it is safe for the background thread to write - // to these variables). The background thread waits on this semaphore once it - // has finished computing the derived variables; and when it successfully - // decrements it, it will write to those variables (quickly, via Swap()). - Semaphore previous_minibatch_empty_; - - - // The 'current_minibatch_ready_' semaphore is incremented by the foreground - // thread from Train(), when it has written the just-provided minibatch to - // 'current_minibatch_' (it's also incremented by the destructor, together - // with setting end_of_input_. The background thread waits on this semaphore - // before either processing previous_minibatch (if !end_of_input_), or exiting - // (if end_of_input_). 
- Semaphore current_minibatch_full_; - - // The 'current_minibatch_empty_' semaphore is incremented by the background - // thread when it has done processing current_minibatch_, - // so, it is safe for the foreground thread to write - // to this variable). The foreground thread waits on this semaphore before - // writing to 'current_minibatch_' (in practice it should get the semaphore - // immediately since we expect that the foreground thread will have more to - // do than the background thread). - Semaphore current_minibatch_empty_; - - std::thread background_thread_; // Background thread for computing 'derived' - // parameters of a minibatch. - // This value is used in backstitch training when we need to ensure // consistent dropout masks. It's set to a value derived from rand() // when the class is initialized.