From 954e69a4f6578072df5be442d5a1b693b8224c24 Mon Sep 17 00:00:00 2001 From: "Jan \"yenda\" Trmal" Date: Mon, 7 May 2018 14:17:31 -0400 Subject: [PATCH] [src] make e2e/"unconstrained" numerator computation faster (#2392) --- COPYING | 12 +- src/base/kaldi-math.h | 2 + src/chain/chain-generic-numerator.cc | 479 +++++++++++++----------- src/chain/chain-generic-numerator.h | 144 ++++--- src/chain/chain-training.cc | 29 +- src/nnet3/nnet-optimize-utils.cc | 19 +- src/nnet3bin/nnet3-egs-augment-image.cc | 16 +- 7 files changed, 388 insertions(+), 313 deletions(-) diff --git a/COPYING b/COPYING index d8804be572c..5a5cab00a29 100644 --- a/COPYING +++ b/COPYING @@ -56,7 +56,7 @@ contributors and original source material as well as the full text of the Apache License v 2.0 are set forth below. Individual Contributors (in alphabetical order) - + Mohit Agarwal Tanel Alumae Gilles Boulianne @@ -123,7 +123,7 @@ Individual Contributors (in alphabetical order) Haihua Xu Hainan Xu Xiaohui Zhang - + Other Source Material This project includes a port and modification of materials from JAMA: A Java @@ -136,9 +136,9 @@ Other Source Material "Signal processing with lapped transforms," Artech House, Inc., 1992. The current copyright holder, Henrique S. Malvar, has given his permission for the release of this modified version under the Apache License 2.0. - - This project includes material from the OpenFST Library v1.2.7 available at - http://www.openfst.org and released under the Apache License v. 2.0. + + This project includes material from the OpenFST Library v1.2.7 available at + http://www.openfst.org and released under the Apache License v. 2.0. 
[OpenFst COPYING file begins here] @@ -147,7 +147,7 @@ Other Source Material You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/src/base/kaldi-math.h b/src/base/kaldi-math.h index afddc5105d4..21665ddfc63 100644 --- a/src/base/kaldi-math.h +++ b/src/base/kaldi-math.h @@ -183,6 +183,7 @@ inline Float RandPrune(Float post, BaseFloat prune_thresh, inline double LogAdd(double x, double y) { double diff; + if (x < y) { diff = x - y; x = y; @@ -203,6 +204,7 @@ inline double LogAdd(double x, double y) { inline float LogAdd(float x, float y) { float diff; + if (x < y) { diff = x - y; x = y; diff --git a/src/chain/chain-generic-numerator.cc b/src/chain/chain-generic-numerator.cc index 392f7b02def..a228acf8102 100644 --- a/src/chain/chain-generic-numerator.cc +++ b/src/chain/chain-generic-numerator.cc @@ -21,6 +21,10 @@ #include "chain/chain-generic-numerator.h" #include "chain/chain-kernels-ansi.h" +#include +#include +#include + namespace kaldi { namespace chain { @@ -33,59 +37,44 @@ GenericNumeratorComputation::GenericNumeratorComputation( const Supervision &supervision, const CuMatrixBase &nnet_output): supervision_(supervision), - nnet_output_deriv_transposed_( - nnet_output.NumCols(), - std::min(nnet_output.NumRows(), - static_cast(kMaxDerivTimeSteps) * - supervision.num_sequences)), - tot_prob_(supervision.num_sequences, kUndefined), - ok_(true) { + nnet_output_(nnet_output) { KALDI_ASSERT(supervision.num_sequences * supervision.frames_per_sequence == nnet_output.NumRows() && supervision.label_dim == nnet_output.NumCols()); - { - CuMatrix exp_nnet_output_transposed_gpu(nnet_output, kTrans); - exp_nnet_output_transposed_gpu.ApplyExp(); - exp_nnet_output_transposed_.Resize(nnet_output.NumCols(), - 
nnet_output.NumRows(), kUndefined); - exp_nnet_output_transposed_.CopyFromMat(exp_nnet_output_transposed_gpu); - } using std::vector; - int32 B = supervision_.num_sequences, - num_frames = supervision_.frames_per_sequence; - KALDI_ASSERT(supervision_.e2e_fsts.size() == B); + int num_sequences = supervision_.num_sequences; + KALDI_ASSERT(supervision_.e2e_fsts.size() == num_sequences); // Find the maximum number of HMM states and then // initialize final probs, alpha, and beta. - max_num_hmm_states_ = 0; - for (int32 i = 0; i < B; i++) { + int max_num_hmm_states = 0; + for (int i = 0; i < num_sequences; i++) { KALDI_ASSERT(supervision_.e2e_fsts[i].Properties(fst::kIEpsilons, true) == 0); - if (supervision_.e2e_fsts[i].NumStates() > max_num_hmm_states_) - max_num_hmm_states_ = supervision_.e2e_fsts[i].NumStates(); + if (supervision_.e2e_fsts[i].NumStates() > max_num_hmm_states) + max_num_hmm_states = supervision_.e2e_fsts[i].NumStates(); } - final_probs_.Resize(max_num_hmm_states_, B, kSetZero); - alpha_.Resize(num_frames + 1, - max_num_hmm_states_ * B + B, - kSetZero); - // The extra B is for storing alpha sums - beta_.Resize(2, max_num_hmm_states_ * B, kSetZero); + final_probs_.Resize(num_sequences, max_num_hmm_states); // Initialize incoming transitions for easy access - in_transitions_.resize(B); // indexed by seq, state - out_transitions_.resize(B); // indexed by seq, state - for (int32 seq = 0; seq < B; seq++) { + in_transitions_.resize(num_sequences); // indexed by seq, state + out_transitions_.resize(num_sequences); // indexed by seq, state + for (int seq = 0; seq < num_sequences; seq++) { in_transitions_[seq] = vector >( supervision_.e2e_fsts[seq].NumStates()); out_transitions_[seq] = vector >( supervision_.e2e_fsts[seq].NumStates()); } - offsets_.Resize(B); - for (int32 seq = 0; seq < B; seq++) { + offsets_.Resize(num_sequences); + std::unordered_map pdf_to_index; + int32 pdf_stride = nnet_output_.Stride(); + int32 view_stride = nnet_output_.Stride() * 
num_sequences; + nnet_output_stride_ = pdf_stride; + for (int seq = 0; seq < num_sequences; seq++) { for (int32 s = 0; s < supervision_.e2e_fsts[seq].NumStates(); s++) { - final_probs_(s, seq) = exp(-supervision_.e2e_fsts[seq].Final(s).Value()); + final_probs_(seq, s)= -supervision_.e2e_fsts[seq].Final(s).Value(); BaseFloat offset = 0.0; if (s == 0) { for (fst::ArcIterator aiter( @@ -98,13 +87,25 @@ GenericNumeratorComputation::GenericNumeratorComputation( } for (fst::ArcIterator aiter( - supervision_.e2e_fsts[seq], s); + supervision_.e2e_fsts[seq], s); !aiter.Done(); aiter.Next()) { const fst::StdArc &arc = aiter.Value(); DenominatorGraphTransition transition; - transition.transition_prob = exp(-(arc.weight.Value() - offset)); - transition.pdf_id = arc.ilabel - 1; + transition.transition_prob = -(arc.weight.Value() - offset); + + int32 pdf_id = arc.ilabel - 1; // note: the FST labels were pdf-id plus one. + + // remap to a unique index in the remapped space + pdf_id = pdf_id + seq * pdf_stride; + KALDI_ASSERT(pdf_id < view_stride); + + if (pdf_to_index.find(pdf_id) == pdf_to_index.end()) { + index_to_pdf_.push_back(pdf_id); + pdf_to_index[pdf_id] = index_to_pdf_.size() - 1; + } + + transition.pdf_id = pdf_to_index[pdf_id]; transition.hmm_state = s; in_transitions_[seq][arc.nextstate].push_back(transition); transition.hmm_state = arc.nextstate; @@ -115,229 +116,273 @@ GenericNumeratorComputation::GenericNumeratorComputation( } -void GenericNumeratorComputation::AlphaFirstFrame() { - const int32 num_sequences = supervision_.num_sequences, - num_states = max_num_hmm_states_; - // Set alpha_0(0) for all sequences to 1.0 and leave the rest to be 0.0. 
- double *first_frame_alpha = alpha_.RowData(0); - SubVector alpha_hmm_state0(first_frame_alpha, num_sequences); - alpha_hmm_state0.Set(1.0); - - // Now compute alpha-sums for t=0 which is obviously 1.0 for each sequence - SubVector alpha_sum_vec(first_frame_alpha + - num_states * num_sequences, - num_sequences); - alpha_sum_vec.Set(1.0); +void GenericNumeratorComputation::AlphaFirstFrame(int seq, + Matrix *alpha) { + const int32 num_frames = supervision_.frames_per_sequence, + num_states = supervision_.e2e_fsts[seq].NumStates(); + alpha->Resize(num_frames + 1, num_states + 1, kSetZero); + alpha->Set(-std::numeric_limits::infinity()); + (*alpha)(0, 0) = 0.0; + (*alpha)(0, num_states) = 0.0; } +void GenericNumeratorComputation::CopySpecificPdfsIndirect( + const CuMatrixBase &nnet_output, + const std::vector &indices, + Matrix *out) { + KALDI_ASSERT(nnet_output_stride_ == nnet_output_.Stride()); + const int32 num_sequences = supervision_.num_sequences, + frames_per_sequence = supervision_.frames_per_sequence; + + const BaseFloat *starting_ptr = nnet_output.RowData(0); + const int view_stride = num_sequences * nnet_output.Stride(); + + const CuSubMatrix sequence_view(starting_ptr, + frames_per_sequence, + view_stride, + view_stride); + + CuArray indices_gpu(indices); + CuMatrix required_pdfs(frames_per_sequence, + indices.size()); + + required_pdfs.CopyCols(sequence_view, indices_gpu); + out->Swap(&required_pdfs); +} + // The alpha computation for some 0 < t <= num_time_steps_. 
-void GenericNumeratorComputation::AlphaGeneralFrame(int32 t) { +BaseFloat GenericNumeratorComputation::AlphaRemainingFrames(int seq, + const Matrix &probs, + Matrix *alpha) { // Define some variables to make things nicer - const int32 - num_sequences = supervision_.num_sequences, - num_frames = supervision_.frames_per_sequence, - num_pdfs = exp_nnet_output_transposed_.NumRows(), - num_states = max_num_hmm_states_; - KALDI_ASSERT(t > 0 && t <= num_frames); - - SubMatrix this_alpha(alpha_.RowData(t), num_states, - num_sequences, num_sequences); - const SubMatrix prev_alpha(alpha_.RowData(t - 1), num_states + 1, - num_sequences, num_sequences); - // 'probs' is the matrix of pseudo-likelihoods for frame t - 1. - SubMatrix probs(exp_nnet_output_transposed_, 0, num_pdfs, - (t - 1) * num_sequences, num_sequences); - - for (int32 seq = 0; seq < num_sequences; seq++) { - double inv_arbitrary_scale = prev_alpha(num_states, seq); + const int32 num_sequences = supervision_.num_sequences, + num_frames = supervision_.frames_per_sequence; + + KALDI_ASSERT(seq >= 0 && seq < num_sequences); + + SubMatrix alpha_view(*alpha, + 0, alpha->NumRows(), + 0, alpha->NumCols()); + + // variables for log_likelihood computation + double log_scale_product = 0, + log_prob_product = 0; + + for (int t = 1; t <= num_frames; ++t) { + SubMatrix prev_alpha_t(alpha_view, t - 1, 1, 0, + alpha_view.NumCols() - 1); + SubMatrix this_alpha_t(alpha_view, t, 1, 0, + alpha_view.NumCols() - 1); + for (int32 h = 0; h < supervision_.e2e_fsts[seq].NumStates(); h++) { for (auto tr = in_transitions_[seq][h].begin(); - tr != in_transitions_[seq][h].end(); tr++) { - double transition_prob = tr->transition_prob; - int32 pdf_id = tr->pdf_id, prev_hmm_state = tr->hmm_state; - double prob = probs(pdf_id, seq); - this_alpha(h, seq) += prev_alpha(prev_hmm_state, seq) / - inv_arbitrary_scale * transition_prob * prob; + tr != in_transitions_[seq][h].end(); tr++) { + BaseFloat transition_prob = tr->transition_prob; + int32 
pdf_id = tr->pdf_id, + prev_hmm_state = tr->hmm_state; + BaseFloat prob = probs(t-1, pdf_id); + alpha_view(t, h) = LogAdd(alpha_view(t, h), + alpha_view(t-1, prev_hmm_state) + transition_prob + prob); } } - } + double sum = alpha_view(t-1, alpha_view.NumCols() - 1); + this_alpha_t.Add(-sum); + sum = this_alpha_t.LogSumExp(); - if (t == num_frames) // last alpha - this_alpha.MulElements(final_probs_); - // Now compute alpha-sums for frame t: - SubVector alpha_sum_vec(alpha_.RowData(t) + num_states * num_sequences, - num_sequences); - alpha_sum_vec.AddRowSumMat(1.0, this_alpha, 0.0); + alpha_view(t, alpha_view.NumCols() - 1) = sum; + log_scale_product += sum; + } + SubMatrix last_alpha(alpha_view, alpha_view.NumRows() - 1, 1, + 0, alpha_view.NumCols() - 1); + SubVector final_probs(final_probs_.RowData(seq), + alpha_view.NumCols() - 1); + + // adjust last_alpha + double sum = alpha_view(alpha_view.NumRows() - 1, alpha_view.NumCols() - 1); + log_scale_product -= sum; + last_alpha.AddVecToRows(1.0, final_probs); + sum = last_alpha.LogSumExp(); + alpha_view(alpha_view.NumRows() - 1, alpha_view.NumCols() - 1) = sum; + + // second part of criterion + log_prob_product = sum - offsets_(seq); + + return log_prob_product + log_scale_product; } -BaseFloat GenericNumeratorComputation::Forward() { - AlphaFirstFrame(); - for (int32 t = 1; t <= supervision_.frames_per_sequence; t++) { - AlphaGeneralFrame(t); +bool GenericNumeratorComputation::ForwardBackward( + BaseFloat *total_loglike, + CuMatrixBase *nnet_output_deriv) { + KALDI_ASSERT(total_loglike != NULL); + KALDI_ASSERT(nnet_output_deriv != NULL); + KALDI_ASSERT(nnet_output_deriv->NumCols() == nnet_output_.NumCols()); + KALDI_ASSERT(nnet_output_deriv->NumRows() == nnet_output_.NumRows()); + + BaseFloat partial_loglike = 0; + const int32 num_sequences = supervision_.num_sequences; + + bool ok = true; + Matrix alpha; + Matrix beta; + Matrix probs; + Matrix derivs; + + // We selectively copy only those pdfs we need + 
CopySpecificPdfsIndirect(nnet_output_, index_to_pdf_, &probs); + + derivs.Resize(probs.NumRows(), probs.NumCols()); + derivs.Set(-std::numeric_limits::infinity()); + + for (int seq = 0; seq < num_sequences; ++seq) { + // Forward part + AlphaFirstFrame(seq, &alpha); + partial_loglike += AlphaRemainingFrames(seq, probs, &alpha); + + // Backward part + BetaLastFrame(seq, alpha, &beta); + BetaRemainingFrames(seq, probs, alpha, &beta, &derivs); + ok = ok || CheckValues(seq, probs, alpha, beta, derivs); } - return ComputeTotLogLike(); + // Transfer and add the derivatives to the values in the matrix + AddSpecificPdfsIndirect(&derivs, index_to_pdf_, nnet_output_deriv); + *total_loglike = partial_loglike; + return ok; } -BaseFloat GenericNumeratorComputation::ComputeTotLogLike() { - const int32 - num_sequences = supervision_.num_sequences, - num_frames = supervision_.frames_per_sequence, - num_states = max_num_hmm_states_; - - // View the last alpha as a matrix of size num-hmm-states by num-sequences. 
- SubMatrix last_alpha(alpha_.RowData(num_frames), num_states, - num_sequences, num_sequences); - tot_prob_.AddRowSumMat(1.0, last_alpha, 0.0); - Vector tot_log_probs(tot_prob_); - tot_log_probs.ApplyLog(); - tot_log_probs.AddVec(-1.0, offsets_); - double tot_log_prob = tot_log_probs.Sum(); - SubMatrix inv_arbitrary_scales(alpha_, 0, num_frames, - num_sequences * num_states, - num_sequences); - Matrix log_inv_arbitrary_scales(inv_arbitrary_scales); - log_inv_arbitrary_scales.ApplyLog(); - double log_inv_arbitrary_scales_product = - log_inv_arbitrary_scales.Sum(); - return tot_log_prob + log_inv_arbitrary_scales_product; -} +BaseFloat GenericNumeratorComputation::ComputeObjf() { + BaseFloat partial_loglike = 0; + const int32 num_sequences = supervision_.num_sequences; + Matrix alpha; + Matrix probs; -bool GenericNumeratorComputation::Backward( - CuMatrixBase *nnet_output_deriv) { - const int32 - num_sequences = supervision_.num_sequences, - num_frames = supervision_.frames_per_sequence, - num_pdfs = exp_nnet_output_transposed_.NumRows(); - BetaLastFrame(); - for (int32 t = num_frames - 1; t >= 0; t--) { - BetaGeneralFrame(t); - if (GetVerboseLevel() >= 1 || t == 0 || t == num_frames - 1) - BetaGeneralFrameDebug(t); - if (t % kMaxDerivTimeSteps == 0) { - // Commit the derivative stored in exp_nnet_output_transposed_ by adding - // its transpose to the appropriate sub-matrix of 'nnet_output_deriv'. 
- int32 chunk_frames = std::min(static_cast(kMaxDerivTimeSteps), - num_frames - t); - SubMatrix transposed_deriv_part( - nnet_output_deriv_transposed_, - 0, num_pdfs, - 0, chunk_frames * num_sequences); - CuMatrix tmp(transposed_deriv_part); - CuSubMatrix output_deriv_part( - *nnet_output_deriv, - t * num_sequences, chunk_frames * num_sequences, - 0, num_pdfs); - output_deriv_part.AddMat(supervision_.weight, tmp, kTrans); - if (t != 0) - transposed_deriv_part.SetZero(); - } + // We selectively copy only those pdfs we need + CopySpecificPdfsIndirect(nnet_output_, index_to_pdf_, &probs); + + for (int seq = 0; seq < num_sequences; ++seq) { + // Forward part + AlphaFirstFrame(seq, &alpha); + partial_loglike += AlphaRemainingFrames(seq, probs, &alpha); } - return ok_; + return partial_loglike; +} + + +BaseFloat GenericNumeratorComputation::GetTotalProb( + const Matrix &alpha) { + return alpha(alpha.NumRows() - 1, alpha.NumCols() - 1); } -void GenericNumeratorComputation::BetaLastFrame() { +void GenericNumeratorComputation::BetaLastFrame(int seq, + const Matrix &alpha, + Matrix *beta) { // Sets up the beta quantity on the last frame (frame == // frames_per_sequence_). Note that the betas we use here contain a // 1/(tot-prob) factor in order to simplify the backprop. 
- int32 t = supervision_.frames_per_sequence; - double *last_frame_beta = beta_.RowData(t % 2); + const int32 num_frames = supervision_.frames_per_sequence, + num_states = supervision_.e2e_fsts[seq].NumStates(); + float tot_prob = GetTotalProb(alpha); - SubMatrix beta_mat(last_frame_beta, - max_num_hmm_states_, - supervision_.num_sequences, - supervision_.num_sequences); + beta->Resize(2, num_states); + beta->Set(-std::numeric_limits::infinity()); - Vector inv_tot_prob(tot_prob_); - inv_tot_prob.InvertElements(); + SubVector beta_mat(beta->RowData(num_frames % 2), num_states); + SubVector final_probs(final_probs_.RowData(seq), num_states); - beta_mat.CopyRowsFromVec(inv_tot_prob); - beta_mat.MulElements(final_probs_); + BaseFloat inv_tot_prob = -tot_prob; + beta_mat.Set(inv_tot_prob); + beta_mat.AddVec(1.0, final_probs); } -void GenericNumeratorComputation::BetaGeneralFrame(int32 t) { +void GenericNumeratorComputation::BetaRemainingFrames(int seq, + const Matrix &probs, + const Matrix &alpha, + Matrix *beta, + Matrix *derivs) { const int32 num_sequences = supervision_.num_sequences, num_frames = supervision_.frames_per_sequence, - num_pdfs = exp_nnet_output_transposed_.NumRows(), - num_states = max_num_hmm_states_; - KALDI_ASSERT(t >= 0 && t < num_frames); - - // t_wrapped gives us the time-index we use when indexing - // nnet_output_deriv_transposed_; to save memory we limit the size of the - // matrix, storing only chunks of frames at a time, and we add it to the - // non-transposed output whenever we finish a chunk. 
- int32 t_wrapped = t % static_cast(kMaxDerivTimeSteps); - const SubMatrix this_alpha(alpha_.RowData(t), num_states, - num_sequences, num_sequences); - SubMatrix this_beta(beta_.RowData(t % 2), num_states, - num_sequences, num_sequences); - const SubMatrix next_beta(beta_.RowData((t + 1) % 2), num_states, - num_sequences, num_sequences); - - SubMatrix probs(exp_nnet_output_transposed_, 0, num_pdfs, - t * num_sequences, num_sequences), - log_prob_deriv(nnet_output_deriv_transposed_, 0, num_pdfs, - t_wrapped * num_sequences, num_sequences); - - for (int32 seq = 0; seq < num_sequences; seq++) { + num_states = supervision_.e2e_fsts[seq].NumStates(); + KALDI_ASSERT(seq >= 0 && seq < num_sequences); + + SubMatrix log_prob_deriv(*derivs, + 0, derivs->NumRows(), + 0, derivs->NumCols()); + + for (int t = num_frames - 1; t >= 0; --t) { + SubVector this_beta(beta->RowData(t % 2), num_states); + const SubVector next_beta(beta->RowData((t + 1) % 2), + num_states); + + BaseFloat inv_arbitrary_scale = alpha(t, num_states); + for (int32 h = 0; h < supervision_.e2e_fsts[seq].NumStates(); h++) { - BaseFloat inv_arbitrary_scale = this_alpha(num_states, seq); - double tot_variable_factor = 0.0; + BaseFloat tot_variable_factor; + tot_variable_factor = -std::numeric_limits::infinity(); for (auto tr = out_transitions_[seq][h].begin(); - tr != out_transitions_[seq][h].end(); tr++) { + tr != out_transitions_[seq][h].end(); tr++) { BaseFloat transition_prob = tr->transition_prob; int32 pdf_id = tr->pdf_id, next_hmm_state = tr->hmm_state; - double variable_factor = transition_prob * - next_beta(next_hmm_state, seq) * - probs(pdf_id, seq) / inv_arbitrary_scale; - tot_variable_factor += variable_factor; - double occupation_prob = variable_factor * this_alpha(h, seq); - log_prob_deriv(pdf_id, seq) += occupation_prob; + BaseFloat variable_factor = transition_prob + + next_beta(next_hmm_state) + + probs(t, pdf_id) - inv_arbitrary_scale; + tot_variable_factor = LogAdd(tot_variable_factor, + 
variable_factor); + + BaseFloat occupation_prob = variable_factor + alpha(t, h); + log_prob_deriv(t, pdf_id) = LogAdd(log_prob_deriv(t, pdf_id), + occupation_prob); } - this_beta(h, seq) = tot_variable_factor; + this_beta(h) = tot_variable_factor; } } } -void GenericNumeratorComputation::BetaGeneralFrameDebug(int32 t) { - int32 alpha_beta_size = max_num_hmm_states_ * supervision_.num_sequences; - SubVector this_alpha(alpha_.RowData(t), alpha_beta_size), - this_beta(beta_.RowData(t % 2), alpha_beta_size); - int32 t_wrapped = t % static_cast(kMaxDerivTimeSteps), - num_pdfs = exp_nnet_output_transposed_.NumRows(); - SubMatrix this_log_prob_deriv( - nnet_output_deriv_transposed_, 0, num_pdfs, - t_wrapped * supervision_.num_sequences, supervision_.num_sequences); - double alpha_beta_product = VecVec(this_alpha, - this_beta), - this_log_prob_deriv_sum = this_log_prob_deriv.Sum(); - if (!ApproxEqual(alpha_beta_product, supervision_.num_sequences)) { - KALDI_WARN << "On time " << t << ", alpha-beta product " - << alpha_beta_product << " != " << supervision_.num_sequences - << " alpha-sum = " << this_alpha.Sum() - << ", beta-sum = " << this_beta.Sum(); - if (fabs(alpha_beta_product - supervision_.num_sequences) > 2.0 - || alpha_beta_product - alpha_beta_product != 0) { - KALDI_WARN << "Excessive error detected, will abandon this minibatch"; - ok_ = false; - } - } - // Use higher tolerance, since we are using randomized pruning for the - // log-prob derivatives. 
- if (!ApproxEqual(this_log_prob_deriv_sum, - supervision_.num_sequences, 0.01)) { - KALDI_WARN << "On time " << t << ", log-prob-deriv sum " - << this_log_prob_deriv_sum << " != " - << supervision_.num_sequences; - if (fabs(this_log_prob_deriv_sum - supervision_.num_sequences) > 2.0 || - this_log_prob_deriv_sum - this_log_prob_deriv_sum != 0) { - KALDI_WARN << "Excessive error detected, will abandon this minibatch"; - ok_ = false; - } + +void GenericNumeratorComputation::AddSpecificPdfsIndirect( + Matrix *logprobs, + const std::vector &indices, + CuMatrixBase *output) { + const int32 num_sequences = supervision_.num_sequences, + frames_per_sequence = supervision_.frames_per_sequence; + + const int view_stride = output->Stride() * num_sequences; + + KALDI_ASSERT(frames_per_sequence * num_sequences == output->NumRows()); + + CuMatrix specific_pdfs; + specific_pdfs.Swap(logprobs); + specific_pdfs.ApplyExp(); + specific_pdfs.Scale(supervision_.weight); + + std::vector indices_expanded(view_stride, -1); + for (int i = 0; i < indices.size(); ++i) { + int pdf_index = indices[i]; + int sequence_local_pdf_index = pdf_index % nnet_output_stride_; + int sequence_index = pdf_index / nnet_output_stride_; + pdf_index = sequence_local_pdf_index + + sequence_index * output->Stride(); + KALDI_ASSERT(pdf_index < view_stride); + KALDI_ASSERT(i < specific_pdfs.NumCols()); + indices_expanded[pdf_index] = i; } + + CuArray cu_indices(indices_expanded); + CuSubMatrix out(output->Data(), frames_per_sequence, + view_stride, view_stride); + + out.AddCols(specific_pdfs, cu_indices); +} + +bool GenericNumeratorComputation::CheckValues(int seq, + const Matrix &probs, + const Matrix &alpha, + const Matrix &beta, + const Matrix &derivs) const { + // empty checks for now + return true; } } // namespace chain diff --git a/src/chain/chain-generic-numerator.h b/src/chain/chain-generic-numerator.h index e7bcb524a95..9c1a826c6fd 100644 --- a/src/chain/chain-generic-numerator.h +++ 
b/src/chain/chain-generic-numerator.h @@ -1,6 +1,7 @@ // chain/chain-generic-numerator.h // Copyright 2017 Hossein Hadian +// 2018 Johns Hopkins University (Jan "Yenda" Trmal) // See ../../COPYING for clarification regarding multiple authors @@ -66,10 +67,38 @@ namespace chain { training). It is the same as DenominatorComputation with 2 differences: [1] it runs on CPU [2] it does not use leakyHMM + The F-B computation is done in log-domain. When the 'e2e' flag of a supervision is set, the ComputeChainObjfAndDeriv function in chain-training.cc uses GenericNumeratorComputation (instead of NumeratorCompuation) to compute the numerator derivatives. + + The implementation tries to optimize the memory transfers. The optimization + uses the observation that for each supervision graph, only a very limited + number of pdfs is needed to evaluate the possible transitions from state + to state. That means that for the F-B, we don't have to transfer the whole + neural network output; we can copy only the limited set of pdf activation + values that will be needed for F-B on the given graph. + + To streamline things, in the constructor of this class, we remap the pdf + indices to a new space and store the bookkeeping info in the index_to_pdf_ + structure. This can be seen as if for each FST we create a subspace that + has only the pdfs that are needed for the given FST (possibly ordered + differently). + + Moreover, we optimize memory transfers. The matrix of nnet outputs can be + reshaped (viewed) as a matrix of dimensions + (frames_per_sequence) x (num_sequences * pdf_stride), where the pdf_stride + is the stride of the original matrix and pdf_stride >= num_pdfs. + When the matrix is viewed this way, it becomes obvious that the pdfs of the + k-th supervision sequence have column index k * pdf_stride + original_pdf_index. + Once this is understood, how to copy all pdfs in one shot should become + obvious. 
+ + The complete F-B is then done in this remapped space and only + when copying the activation values from the GPU memory or copying + the computed derivatives to GPU memory, we use the bookkeeping info to + map the values correctly. */ @@ -81,90 +110,95 @@ namespace chain { // and the numerator FSTs are stored in 'e2e_fsts' instead of 'fst' class GenericNumeratorComputation { - public: - /// Initializes the object. GenericNumeratorComputation(const Supervision &supervision, const CuMatrixBase &nnet_output); - // Does the forward computation. Returns the total log-prob multiplied - // by supervision_.weight. - BaseFloat Forward(); - - // Does the backward computation and (efficiently) adds the derivative of the + // Does the forward-backward computation. Returns the total log-prob + // multiplied by supervision_.weight. + // In the backward computation, add (efficiently) the derivative of the // nnet output w.r.t. the (log-prob times supervision_.weight times // deriv_weight) to 'nnet_output_deriv'. - bool Backward(CuMatrixBase *nnet_output_deriv); + bool ForwardBackward(BaseFloat *total_loglike, + CuMatrixBase *nnet_output_deriv); + BaseFloat ComputeObjf(); private: - - // Defining this constant as an enum is easier. it controls a memory/speed - // tradeoff, determining how many frames' worth of the transposed derivative - // we store at a time. It's not very critical; the only disadvantage from - // setting it small is that we have to invoke an AddMat kernel more times. - enum { kMaxDerivTimeSteps = 8 }; + // For the remapped FSTs, copy the appropriate activations to CPU memory. + // For explanation of what remapped FST is, see the large comment in the + // beginning of the file + void CopySpecificPdfsIndirect( + const CuMatrixBase &nnet_output, + const std::vector &indices, + Matrix *output); + + // For the remapped FSTs, copy the computed values back to gpu, + // expand to the original shape and add to the output matrix. 
+ // For explanation of what remapped FST is, see the large comment in the + // beginning of the file. + void AddSpecificPdfsIndirect( + Matrix *logprobs, + const std::vector &indices, + CuMatrixBase *output); // sets up the alpha for frame t = 0. - void AlphaFirstFrame(); - - // the alpha computation for some 0 < t <= num_time_steps_. - void AlphaGeneralFrame(int32 t); - - BaseFloat ComputeTotLogLike(); - - // sets up the beta for frame t = num_time_steps_. - void BetaLastFrame(); - - // the beta computation for 0 <= beta < num_time_steps_. - void BetaGeneralFrame(int32 t); + void AlphaFirstFrame(int seq, Matrix *alpha); + + // the alpha computation for 0 < t <= supervision_.frames_per_sequence + // for some 0 <= seq < supervision_.num_sequences. + BaseFloat AlphaRemainingFrames(int seq, + const Matrix &probs, + Matrix *alpha); + + // the beta computation for 0 <= t < supervision_.frames_per_sequence + // for some 0 <= seq < supervision_.num_sequences. + void BetaRemainingFrames(int32 seq, + const Matrix &probs, + const Matrix &alpha, + Matrix *beta, + Matrix *derivs); + + // the beta computation for t = supervision_.frames_per_sequence + void BetaLastFrame(int seq, + const Matrix &alpha, + Matrix *beta); + + // returns total prob for the given matrix alpha (assumes the alpha + // matrix was computed using AlphaFirstFrame() and AlphaRemainingFrames() + // (it's exactly like 'tot_probe_' in DenominatorComputation) + BaseFloat GetTotalProb(const Matrix &alpha); // some checking that we can do if debug mode is activated, or on frame zero. // Sets ok_ to false if a bad problem is detected. - void BetaGeneralFrameDebug(int32 t); + bool CheckValues(int32 seq, + const Matrix &probs, + const Matrix &alpha, + const Matrix &beta, + const Matrix &derivs) const; const Supervision &supervision_; - // the transposed neural net output. - Matrix exp_nnet_output_transposed_; + // a reference to the nnet output. 
+ const CuMatrixBase &nnet_output_; + int32 nnet_output_stride_; // we keep the original stride extra + // as the matrix can change before ForwardBackward // in_transitions_ lists all the incoming transitions for // each state of each numerator graph // out_transitions_ does the same but for the outgoing transitions - std::vector > > - in_transitions_, out_transitions_; + typedef std::vector > TransitionMap; + std::vector in_transitions_, out_transitions_; + std::vector index_to_pdf_; // final probs for each state of each numerator graph - Matrix final_probs_; // indexed by seq, state + Matrix final_probs_; // indexed by seq, state // an offset subtracted from the logprobs of transitions out of the first // state of each graph to help reduce numerical problems. Note the // generic forward-backward computations cannot be done in log-space. Vector offsets_; - - // maximum number of states among all the numerator graphs - // (it is used as a stride in alpha_ and beta_) - int32 max_num_hmm_states_; - - // the derivs w.r.t. the nnet outputs (transposed) - // (the dimensions and functionality is the same as in - // DenominatorComputation) - Matrix nnet_output_deriv_transposed_; - - // forward and backward probs matrices. These have the - // same dimension and functionality as alpha_ and beta_ - // in DenominatorComputation except here we don't use beta - // sums (becasue we don't use leakyHMM). However, we use - // alpha sums to help avoid numerical issues. - Matrix alpha_; - Matrix beta_; - - // vector of total probs (i.e. 
for all the sequences) - // (it's exactly like 'tot_probe_' in DenominatorComputation) - Vector tot_prob_; - - bool ok_; }; } // namespace chain diff --git a/src/chain/chain-training.cc b/src/chain/chain-training.cc index 0dd12633c74..6b4a7b593c2 100644 --- a/src/chain/chain-training.cc +++ b/src/chain/chain-training.cc @@ -75,25 +75,22 @@ void ComputeChainObjfAndDerivE2e(const ChainTrainingOptions &opts, GenericNumeratorComputation numerator(supervision, nnet_output); // note: supervision.weight is included as a factor in the derivative from // the numerator object, as well as the returned logprob. - num_logprob_weighted = numerator.Forward(); - KALDI_VLOG(2) << "Numerator logprob per frame: " - << num_logprob_weighted / (*weight); - numerator_ok = (num_logprob_weighted - num_logprob_weighted == 0); - if (!numerator_ok) - KALDI_LOG << "Numerator forward failed."; - - if (xent_output_deriv && numerator_ok) { - numerator_ok = numerator.Backward(xent_output_deriv); - if (!numerator_ok) - KALDI_LOG << "Numerator backward failed."; - if (nnet_output_deriv) + if (xent_output_deriv) { + numerator_ok = numerator.ForwardBackward(&num_logprob_weighted, + xent_output_deriv); + if (numerator_ok && nnet_output_deriv) nnet_output_deriv->AddMat(1.0, *xent_output_deriv); - } else if (nnet_output_deriv && numerator_ok) { - numerator_ok = numerator.Backward(nnet_output_deriv); - if (!numerator_ok) - KALDI_LOG << "Numerator backward failed."; + } else if (nnet_output_deriv) { + numerator_ok = numerator.ForwardBackward(&num_logprob_weighted, + nnet_output_deriv); + } else { + num_logprob_weighted = numerator.ComputeObjf(); } + if (!numerator_ok) + KALDI_WARN << "Numerator forward-backward failed."; } + numerator_ok = numerator_ok && + (num_logprob_weighted - num_logprob_weighted == 0); *objf = num_logprob_weighted - den_logprob_weighted; if (!((*objf) - (*objf) == 0) || !denominator_ok || !numerator_ok) { diff --git a/src/nnet3/nnet-optimize-utils.cc 
b/src/nnet3/nnet-optimize-utils.cc index e587c7ff947..815ce5e9ed0 100644 --- a/src/nnet3/nnet-optimize-utils.cc +++ b/src/nnet3/nnet-optimize-utils.cc @@ -696,18 +696,15 @@ void RenumberComputation(NnetComputation *computation) { } +static bool IsNoop(const NnetComputation::Command &command) { + return command.command_type == kNoOperation; +} + void RemoveNoOps(NnetComputation *computation) { - std::vector::iterator - input_iter = computation->commands.begin(), - input_end = computation->commands.end(), - output_iter = computation->commands.begin(); - for (; input_iter != input_end; ++input_iter) { - if (input_iter->command_type != kNoOperation) { - *output_iter = *input_iter; - ++output_iter; - } - } - computation->commands.resize(output_iter - computation->commands.begin()); + computation->commands.erase( + std::remove_if(computation->commands.begin(), + computation->commands.end(), + IsNoop), computation->commands.end()); } diff --git a/src/nnet3bin/nnet3-egs-augment-image.cc b/src/nnet3bin/nnet3-egs-augment-image.cc index 6020036cc29..ef724d0c6a6 100644 --- a/src/nnet3bin/nnet3-egs-augment-image.cc +++ b/src/nnet3bin/nnet3-egs-augment-image.cc @@ -66,8 +66,8 @@ struct ImageAugmentationConfig { po->Register("rotation-prob", &rotation_prob, "Probability of doing rotation"); po->Register("fill-mode", &fill_mode_string, "Mode for dealing with " - "points outside the image boundary when applying transformation. " - "Choices = {nearest, reflect}"); + "points outside the image boundary when applying transformation. 
" + "Choices = {nearest, reflect}"); } void Check() const { @@ -87,10 +87,10 @@ struct ImageAugmentationConfig { fill_mode = kReflect; } else { if (fill_mode_string != "nearest") { - KALDI_ERR << "Choices for --fill-mode are 'nearest' or 'reflect', got: " - << fill_mode_string; + KALDI_ERR << "Choices for --fill-mode are 'nearest' or 'reflect', got: " + << fill_mode_string; } else { - fill_mode = kNearest; + fill_mode = kNearest; } } return fill_mode; @@ -243,7 +243,7 @@ void PerturbImage(const ImageAugmentationConfig &config, // 0 0 1 ] if (RandUniform() <= config.rotation_prob) { BaseFloat theta = (2 * config.rotation_degree * RandUniform() - - config.rotation_degree) / 180.0 * M_PI; + config.rotation_degree) / 180.0 * M_PI; rotation_mat(0, 0) = cos(theta); rotation_mat(0, 1) = -sin(theta); rotation_mat(1, 0) = sin(theta); @@ -325,8 +325,8 @@ void PerturbImageInNnetExample( } -} // namespace nnet3 -} // namespace kaldi +} // namespace nnet3 +} // namespace kaldi int main(int argc, char *argv[]) { try {