[src] Minor optimizations in "e2e" numerator code (kaldi-asr#2508)
hhadian authored and danpovey committed Jun 19, 2018
1 parent 598b177 commit 775c770
Showing 2 changed files with 59 additions and 43 deletions.
97 changes: 57 additions & 40 deletions src/chain/chain-generic-numerator.cc
@@ -68,9 +68,10 @@ GenericNumeratorComputation::GenericNumeratorComputation(
}

offsets_.Resize(num_sequences);
- std::unordered_map<int, MatrixIndexT> pdf_to_index;
+ std::unordered_map<int32, MatrixIndexT> pdf_to_index;
int32 pdf_stride = nnet_output_.Stride();
int32 view_stride = nnet_output_.Stride() * num_sequences;
+ pdf_to_index.reserve(view_stride);
nnet_output_stride_ = pdf_stride;
for (int seq = 0; seq < num_sequences; seq++) {
for (int32 s = 0; s < supervision_.e2e_fsts[seq].NumStates(); s++) {
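
This first hunk tightens the pdf-to-column map behind the per-sequence view of the network output: the key becomes an explicit int32, and reserve() is called with the view stride before the supervision FSTs are traversed, so the table is sized once instead of rehashing as entries are added. Below is a minimal standalone sketch of the same pattern in plain C++; it is not part of the commit, and the function and variable names are illustrative rather than Kaldi's.

    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    // Sketch only: build a pdf-id -> column-index map, reserving capacity up
    // front so the table never rehashes inside the insertion loop.
    std::unordered_map<int32_t, int32_t> BuildPdfToIndex(
        const std::vector<int32_t> &pdf_ids) {
      std::unordered_map<int32_t, int32_t> pdf_to_index;
      pdf_to_index.reserve(pdf_ids.size());  // one allocation up front
      int32_t next_column = 0;
      for (int32_t pdf : pdf_ids) {
        // emplace() reports whether the key was newly inserted; only a
        // first-seen pdf consumes a new column.
        if (pdf_to_index.emplace(pdf, next_column).second)
          ++next_column;
      }
      return pdf_to_index;
    }

Note that reserve() only sets bucket capacity; the column counting relies on emplace() reporting whether the key was new.
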
@@ -161,49 +162,46 @@ BaseFloat GenericNumeratorComputation::AlphaRemainingFrames(int seq,

KALDI_ASSERT(seq >= 0 && seq < num_sequences);

- SubMatrix<BaseFloat> alpha_view(*alpha,
- 0, alpha->NumRows(),
- 0, alpha->NumCols());

// variables for log_likelihood computation
double log_scale_product = 0,
log_prob_product = 0;

for (int t = 1; t <= num_frames; ++t) {
- SubMatrix<BaseFloat> prev_alpha_t(alpha_view, t - 1, 1, 0,
- alpha_view.NumCols() - 1);
- SubMatrix<BaseFloat> this_alpha_t(alpha_view, t, 1, 0,
- alpha_view.NumCols() - 1);
+ const BaseFloat *probs_tm1 = probs.RowData(t - 1);
+ BaseFloat *alpha_t = alpha->RowData(t);
+ const BaseFloat *alpha_tm1 = alpha->RowData(t - 1);

for (int32 h = 0; h < supervision_.e2e_fsts[seq].NumStates(); h++) {
for (auto tr = in_transitions_[seq][h].begin();
- tr != in_transitions_[seq][h].end(); tr++) {
+ tr != in_transitions_[seq][h].end(); ++tr) {
BaseFloat transition_prob = tr->transition_prob;
int32 pdf_id = tr->pdf_id,
prev_hmm_state = tr->hmm_state;
- BaseFloat prob = probs(t-1, pdf_id);
- alpha_view(t, h) = LogAdd(alpha_view(t, h),
- alpha_view(t-1, prev_hmm_state) + transition_prob + prob);
+ BaseFloat prob = probs_tm1[pdf_id];
+ alpha_t[h] = LogAdd(alpha_t[h],
+ alpha_tm1[prev_hmm_state] + transition_prob + prob);
}
}
- double sum = alpha_view(t-1, alpha_view.NumCols() - 1);
- this_alpha_t.Add(-sum);
- sum = this_alpha_t.LogSumExp();
+ double sum = alpha_tm1[alpha->NumCols() - 1];
+ SubMatrix<BaseFloat> alpha_t_mat(*alpha, t, 1, 0,
+ alpha->NumCols() - 1);
+ alpha_t_mat.Add(-sum);
+ sum = alpha_t_mat.LogSumExp();

- alpha_view(t, alpha_view.NumCols() - 1) = sum;
+ alpha_t[alpha->NumCols() - 1] = sum;
log_scale_product += sum;
}
- SubMatrix<BaseFloat> last_alpha(alpha_view, alpha_view.NumRows() - 1, 1,
- 0, alpha_view.NumCols() - 1);
+ SubMatrix<BaseFloat> last_alpha(*alpha, alpha->NumRows() - 1, 1,
+ 0, alpha->NumCols() - 1);
SubVector<BaseFloat> final_probs(final_probs_.RowData(seq),
- alpha_view.NumCols() - 1);
+ alpha->NumCols() - 1);

// adjust last_alpha
- double sum = alpha_view(alpha_view.NumRows() - 1, alpha_view.NumCols() - 1);
+ double sum = (*alpha)(alpha->NumRows() - 1, alpha->NumCols() - 1);
log_scale_product -= sum;
last_alpha.AddVecToRows(1.0, final_probs);
sum = last_alpha.LogSumExp();
- alpha_view(alpha_view.NumRows() - 1, alpha_view.NumCols() - 1) = sum;
+ (*alpha)(alpha->NumRows() - 1, alpha->NumCols() - 1) = sum;

// second part of criterion
log_prob_product = sum - offsets_(seq);
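
The AlphaRemainingFrames changes replace per-element SubMatrix accesses such as alpha_view(t, h) and probs(t-1, pdf_id) with raw row pointers fetched once per frame, removing repeated index arithmetic from the innermost loop, while keeping the per-frame renormalization by the previous frame's stored log-sum. Below is a self-contained sketch of one frame of this log-space recursion in plain C++; it is not Kaldi code, and the Transition struct, kLogZero, and AlphaOneFrame are assumed names used only for illustration.

    #include <cmath>
    #include <limits>
    #include <utility>
    #include <vector>

    const float kLogZero = -std::numeric_limits<float>::infinity();

    // Numerically stable log(exp(a) + exp(b)).
    float LogAdd(float a, float b) {
      if (a < b) std::swap(a, b);  // now a >= b
      if (b == kLogZero) return a;
      return a + std::log1p(std::exp(b - a));
    }

    struct Transition { int src; int dst; int pdf_id; float log_prob; };

    // One frame of the log-space forward recursion.  Each alpha row carries
    // one extra trailing entry holding that frame's log normalizer, mirroring
    // the extra column used in the Kaldi matrices.
    void AlphaOneFrame(const std::vector<Transition> &transitions,
                       const std::vector<float> &probs_tm1,  // log-likes at t - 1
                       const std::vector<float> &alpha_tm1,
                       std::vector<float> *alpha_t) {
      int num_states = static_cast<int>(alpha_tm1.size()) - 1;
      alpha_t->assign(num_states + 1, kLogZero);
      for (const Transition &tr : transitions)
        (*alpha_t)[tr.dst] =
            LogAdd((*alpha_t)[tr.dst],
                   alpha_tm1[tr.src] + tr.log_prob + probs_tm1[tr.pdf_id]);
      // Renormalize by the previous frame's total, then store this frame's total.
      float prev_sum = alpha_tm1[num_states];
      float this_sum = kLogZero;
      for (int h = 0; h < num_states; h++) {
        (*alpha_t)[h] -= prev_sum;
        this_sum = LogAdd(this_sum, (*alpha_t)[h]);
      }
      (*alpha_t)[num_states] = this_sum;
    }
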
@@ -242,7 +240,8 @@ bool GenericNumeratorComputation::ForwardBackward(
// Backward part
BetaLastFrame(seq, alpha, &beta);
BetaRemainingFrames(seq, probs, alpha, &beta, &derivs);
- ok = ok || CheckValues(seq, probs, alpha, beta, derivs);
+ if (GetVerboseLevel() >= 1)
+ ok = ok && CheckValues(seq, probs, alpha, beta, derivs);
}
// Transfer and add the derivatives to the values in the matrix
AddSpecificPdfsIndirect(&derivs, index_to_pdf_, nnet_output_deriv);
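
The ForwardBackward change also fixes how the sanity check feeds the return value: with the old ok = ok || CheckValues(...), the || short-circuits once ok is true, so the check did not run and a failure could never mark the minibatch as bad, whereas the new code runs the check only at verbose level 1 or higher and folds its result in with &&. A small standalone demonstration of that difference in plain C++ (Check and the verbose flag are toy stand-ins, not Kaldi API):

    #include <cstdio>

    // Toy stand-in for the per-sequence check: prints when it runs and
    // returns the value it is given.
    static bool Check(bool result) { std::printf("check ran\n"); return result; }

    int main() {
      bool ok = true;

      // Old pattern: once ok is true, '||' short-circuits, the check never
      // runs, and a failing check could not flip ok back to false.
      ok = ok || Check(false);   // prints nothing; ok stays true

      // New pattern: run the check only when asked for, and let a failure
      // propagate through '&&'.
      bool verbose = true;       // stands in for GetVerboseLevel() >= 1
      if (verbose)
        ok = ok && Check(false); // prints "check ran"; ok becomes false

      std::printf("ok = %d\n", ok);
      return 0;
    }
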
@@ -268,7 +267,6 @@ BaseFloat GenericNumeratorComputation::ComputeObjf() {
return partial_loglike;
}


BaseFloat GenericNumeratorComputation::GetTotalProb(
const Matrix<BaseFloat> &alpha) {
return alpha(alpha.NumRows() - 1, alpha.NumCols() - 1);
@@ -306,36 +304,33 @@ void GenericNumeratorComputation::BetaRemainingFrames(int seq,
num_states = supervision_.e2e_fsts[seq].NumStates();
KALDI_ASSERT(seq >= 0 && seq < num_sequences);

- SubMatrix<BaseFloat> log_prob_deriv(*derivs,
- 0, derivs->NumRows(),
- 0, derivs->NumCols());

for (int t = num_frames - 1; t >= 0; --t) {
- SubVector<BaseFloat> this_beta(beta->RowData(t % 2), num_states);
- const SubVector<BaseFloat> next_beta(beta->RowData((t + 1) % 2),
- num_states);

- BaseFloat inv_arbitrary_scale = alpha(t, num_states);
+ const BaseFloat *alpha_t = alpha.RowData(t),
+ *beta_tp1 = beta->RowData((t + 1) % 2),
+ *probs_t = probs.RowData(t);
+ BaseFloat *log_prob_deriv_t = derivs->RowData(t),
+ *beta_t = beta->RowData(t % 2);

+ BaseFloat inv_arbitrary_scale = alpha_t[num_states];
for (int32 h = 0; h < supervision_.e2e_fsts[seq].NumStates(); h++) {
BaseFloat tot_variable_factor;
tot_variable_factor = -std::numeric_limits<BaseFloat>::infinity();
for (auto tr = out_transitions_[seq][h].begin();
- tr != out_transitions_[seq][h].end(); tr++) {
+ tr != out_transitions_[seq][h].end(); ++tr) {
BaseFloat transition_prob = tr->transition_prob;
int32 pdf_id = tr->pdf_id,
next_hmm_state = tr->hmm_state;
BaseFloat variable_factor = transition_prob +
- next_beta(next_hmm_state) +
- probs(t, pdf_id) - inv_arbitrary_scale;
+ beta_tp1[next_hmm_state] +
+ probs_t[pdf_id] - inv_arbitrary_scale;
tot_variable_factor = LogAdd(tot_variable_factor,
variable_factor);

- BaseFloat occupation_prob = variable_factor + alpha(t, h);
- log_prob_deriv(t, pdf_id) = LogAdd(log_prob_deriv(t, pdf_id),
+ BaseFloat occupation_prob = variable_factor + alpha_t[h];
+ log_prob_deriv_t[pdf_id] = LogAdd(log_prob_deriv_t[pdf_id],
occupation_prob);
}
- this_beta(h) = tot_variable_factor;
+ beta_t[h] = tot_variable_factor;
}
}
}
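
BetaRemainingFrames keeps its existing structure, where beta occupies only two rows selected by t % 2 because frame t reads only frame t + 1, and swaps the SubVector and SubMatrix element accesses for raw row pointers. The sketch below shows the shape of that backward sweep in plain C++, in ordinary probability space for brevity and without the per-frame scaling; the Arc struct and BackwardPass name are illustrative assumptions, and the real code works in log space with LogAdd.

    #include <vector>

    struct Arc { int src; int dst; int pdf_id; float prob; };

    // Two beta rows, indexed by t % 2, are enough because frame t only reads
    // frame t + 1.  Sketch only; scaling and log-space arithmetic omitted.
    void BackwardPass(const std::vector<Arc> &arcs,
                      const std::vector<std::vector<float> > &probs,  // [t][pdf_id]
                      const std::vector<std::vector<float> > &alpha,  // [t][state]
                      int num_states,
                      std::vector<std::vector<float> > *derivs) {     // [t][pdf_id]
      int num_frames = static_cast<int>(probs.size());
      int num_pdfs = probs.empty() ? 0 : static_cast<int>(probs[0].size());
      derivs->assign(num_frames, std::vector<float>(num_pdfs, 0.0f));
      std::vector<std::vector<float> > beta(2, std::vector<float>(num_states, 0.0f));
      beta[num_frames % 2].assign(num_states, 1.0f);  // simplified final-frame beta
      for (int t = num_frames - 1; t >= 0; --t) {
        const std::vector<float> &beta_tp1 = beta[(t + 1) % 2];  // read side
        std::vector<float> &beta_t = beta[t % 2];                // write side
        beta_t.assign(num_states, 0.0f);
        for (const Arc &a : arcs) {
          float variable_factor = a.prob * probs[t][a.pdf_id] * beta_tp1[a.dst];
          beta_t[a.src] += variable_factor;
          // Unnormalized occupation statistic; the real code also divides by
          // the frame's total (the "arbitrary scale") and accumulates with LogAdd.
          (*derivs)[t][a.pdf_id] += alpha[t][a.src] * variable_factor;
        }
      }
    }

Keeping only two beta rows holds memory at O(num_states) regardless of the number of frames, which is why the per-frame results live in the derivative matrix rather than in beta.
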
@@ -381,7 +376,29 @@ bool GenericNumeratorComputation::CheckValues(int seq,
const Matrix<BaseFloat> &alpha,
const Matrix<BaseFloat> &beta,
const Matrix<BaseFloat> &derivs) const {
- // empty checks for now
+ const int32 num_frames = supervision_.frames_per_sequence;
+ // only check the derivs for the first and last frames
+ const std::vector<int32> times = {0, num_frames - 1};
+ for (const int32 t: times) {
+ BaseFloat deriv_sum = 0.0;
+ for (int32 n = 0; n < probs.NumCols(); n++) {
+ int32 pdf_stride = nnet_output_.Stride();
+ int32 pdf2seq = index_to_pdf_[n] / pdf_stride;
+ if (pdf2seq != seq) // this pdf is not in the space of this sequence
+ continue;
+ deriv_sum += Exp(derivs(t, n));
+ }
+
+ if (!ApproxEqual(deriv_sum, 1.0)) {
+ KALDI_WARN << "On time " << t
+ << " for seq " << seq << ", deriv sum "
+ << deriv_sum << " != 1.0";
+ if (fabs(deriv_sum - 1.0) > 0.05 || deriv_sum - deriv_sum != 0) {
+ KALDI_WARN << "Excessive error detected, will abandon this minibatch";
+ return false;
+ }
+ }
+ }
return true;
}
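
The new CheckValues body exploits the fact that the exponentiated log-derivatives on a frame are occupation probabilities, so over the pdfs belonging to one sequence they should sum to roughly one, and the expression deriv_sum - deriv_sum != 0 flags a NaN or infinite sum. A compressed standalone version of the same test in plain C++ follows; CheckDerivRow and its tolerance argument are illustrative, while the committed version warns whenever the sum is not approximately 1.0 and abandons the minibatch only when the error exceeds 0.05 or the sum is not finite.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Sketch of the per-frame sanity check: exponentiated log-derivatives are
    // occupation probabilities, so their sum should be close to 1.  The test
    // 'sum - sum != 0' is true exactly when sum is NaN or infinite.
    bool CheckDerivRow(const std::vector<double> &log_derivs,
                       double tolerance = 0.05) {
      double sum = 0.0;
      for (double d : log_derivs)
        sum += std::exp(d);
      if (std::fabs(sum - 1.0) > tolerance || sum - sum != 0) {
        std::fprintf(stderr, "deriv sum %f != 1.0, abandoning this minibatch\n", sum);
        return false;
      }
      return true;
    }
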

5 changes: 2 additions & 3 deletions src/chain/chain-generic-numerator.h
@@ -170,7 +170,7 @@ class GenericNumeratorComputation {
BaseFloat GetTotalProb(const Matrix<BaseFloat> &alpha);

// some checking that we can do if debug mode is activated, or on frame zero.
- // Sets ok_ to false if a bad problem is detected.
+ // Returns false if a bad problem is detected.
bool CheckValues(int32 seq,
const Matrix<BaseFloat> &probs,
const Matrix<BaseFloat> &alpha,
@@ -196,8 +196,7 @@ class GenericNumeratorComputation {
Matrix<BaseFloat> final_probs_; // indexed by seq, state

// an offset subtracted from the logprobs of transitions out of the first
- // state of each graph to help reduce numerical problems. Note the
- // generic forward-backward computations cannot be done in log-space.
+ // state of each graph to help reduce numerical problems.
Vector<BaseFloat> offsets_;
};
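
The trimmed comment still describes the role of offsets_: shifting log-probabilities by a per-sequence constant keeps the quantities that later get exponentiated in a safe numeric range, and the constant re-enters when the objective is formed (log_prob_product = sum - offsets_(seq) in the .cc file above). The standard form of that trick is the shifted log-sum-exp below, a plain C++ sketch in which choosing the maximum as the offset is a textbook simplification rather than what the numerator code does.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // The shift trick behind offsets_: subtract a constant from the log
    // values before exponentiating so every term stays in range, then add the
    // constant back to the final total.  Assumes a non-empty input.
    double LogSumExp(const std::vector<double> &log_values) {
      double offset = *std::max_element(log_values.begin(), log_values.end());
      double sum = 0.0;
      for (double v : log_values)
        sum += std::exp(v - offset);  // each term is <= 1, so no overflow
      return offset + std::log(sum);
    }
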
