Skip to content

Commit

Permalink
[src] bug-fixes for end2end chain code (kaldi-asr#2270)
Browse files Browse the repository at this point in the history
  • Loading branch information
hhadian authored and danpovey committed Mar 12, 2018
1 parent 0360215 commit 19dc26f
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/chain/chain-training.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ void ComputeChainObjfAndDerivE2e(const ChainTrainingOptions &opts,
if (nnet_output_deriv)
nnet_output_deriv->AddMat(1.0, *xent_output_deriv);
} else if (nnet_output_deriv && numerator_ok) {
numerator.Backward(nnet_output_deriv);
numerator_ok = numerator.Backward(nnet_output_deriv);
if (!numerator_ok)
KALDI_LOG << "Numerator backward failed.";
}
}

Expand Down Expand Up @@ -128,9 +130,8 @@ void ComputeChainObjfAndDerivE2e(const ChainTrainingOptions &opts,
KALDI_LOG << "Derivs per frame are " << row_products_per_frame;
}

if (opts.l2_regularize == 0.0) {
*l2_term = 0.0;
} else if (numerator_ok) { // we should have some derivs to include a L2 term
*l2_term = 0.0;
if (opts.l2_regularize != 0.0 && numerator_ok) { // we should have some derivs to include a L2 term
// compute the l2 penalty term and its derivative
BaseFloat scale = supervision.weight * opts.l2_regularize;
*l2_term = -0.5 * scale * TraceMatMat(nnet_output, nnet_output, kTrans);
Expand Down

0 comments on commit 19dc26f

Please sign in to comment.