Skip to content

Commit

Permalink
[src] Make user parameter check ERROR not ASSERT (kaldi-asr#4181)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkm000 committed Jul 21, 2020
1 parent ef4b376 commit bd2521c
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/chain/chain-generic-numerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,21 @@ GenericNumeratorComputation::GenericNumeratorComputation(
supervision_(supervision),
nnet_output_(nnet_output),
opts_(opts) {
KALDI_ASSERT(supervision.num_sequences *
supervision.frames_per_sequence == nnet_output.NumRows() &&
supervision.label_dim == nnet_output.NumCols());
if (supervision.num_sequences * supervision.frames_per_sequence !=
nnet_output.NumRows()) {
KALDI_ERR << "Dimension mismatch: nnet_output number of rows "
<< nnet_output.NumRows() << " must be equal to the total number "
<< "of frames in the supervision object: (num_sequences="
<< supervision.num_sequences << ") * (frames_per_sequence="
<< supervision.frames_per_sequence << ") = "
<< supervision.num_sequences * supervision.frames_per_sequence;
}
if (supervision.label_dim != nnet_output.NumCols()) {
KALDI_ERR << "Dimension mismatch: nnet_output number of columns "
<< nnet_output.NumCols() << " is not equal to the supervision "
<< "object feature dimension label_dim=" << supervision.label_dim;
}

NVTX_RANGE(__func__);

using std::vector;
Expand Down Expand Up @@ -242,13 +254,13 @@ bool GenericNumeratorComputation::ForwardBackward(
unsigned int nthreads = opts_.num_threads > 0 ? opts_.num_threads :
std::thread::hardware_concurrency();
// Naive load balancing, each thread gets a chunk of the sequences to process
unsigned int num_sequences_per_thread =
unsigned int num_sequences_per_thread =
(num_sequences + nthreads - 1) / nthreads;

// Allocate one alpha and beta matrix per thread to avoid contention
std::vector<Matrix<BaseFloat>> alpha(nthreads);
std::vector<Matrix<BaseFloat>> beta(nthreads);

// Per thread partial values and boolean
std::vector<BaseFloat> partial_loglike_mt(nthreads, static_cast<BaseFloat>(0));
std::vector<bool> ok_mt(nthreads, true);
Expand Down

0 comments on commit bd2521c

Please sign in to comment.