diff --git a/egs/swbd/s5c/local/chain/run_tdnn_a.sh b/egs/swbd/s5c/local/chain/run_tdnn_a.sh index 3f89db075a6..da68a6bad67 100755 --- a/egs/swbd/s5c/local/chain/run_tdnn_a.sh +++ b/egs/swbd/s5c/local/chain/run_tdnn_a.sh @@ -119,7 +119,8 @@ fi if [ $stage -le 13 ]; then # Note: it might appear that this $lang directory is mismatched, and it is as - # far as the 'topo' + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. utils/mkgraph.sh --transition-scale 0.0 \ --self-loop-scale 0.0 data/lang_sw1_tg $dir $dir/graph_sw1_tg fi @@ -129,7 +130,7 @@ graph_dir=$dir/graph_sw1_tg if [ $stage -le 14 ]; then for decode_set in train_dev eval2000; do ( - steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --iter 298_cached \ + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ --nj 50 --cmd "$decode_cmd" \ --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ $graph_dir data/${decode_set}_hires $dir/decode_${decode_set}_${decode_suff} || exit 1; @@ -138,7 +139,7 @@ if [ $stage -le 14 ]; then data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ $dir/decode_${decode_set}_sw1_{tg,fsh_fg} || exit 1; fi - ) # & + ) & done fi wait; diff --git a/egs/swbd/s5c/local/score_sclite.sh b/egs/swbd/s5c/local/score_sclite.sh index 2031e472fe5..556afbf02bc 100755 --- a/egs/swbd/s5c/local/score_sclite.sh +++ b/egs/swbd/s5c/local/score_sclite.sh @@ -51,18 +51,16 @@ nnet3-ctc-info --print-args=false $model 1>/dev/null 2>&1; [ $? -eq 0 ] && is_ctc=true; [ -z $is_ctc ] && echo "Unknown model type, verify if $model exists" && exit -1; align_word= -reorder= +reorder_opt= if $reverse; then align_word="lattice-reverse ark:- ark:- |" - reorder="--reorder=false" + reorder_opt="--reorder=false" fi if $is_ctc ; then - echo "Warning : This is a CTC model, using corresponding scoring pipeline." + echo "Warning : This is a 'chain' model, using corresponding scoring pipeline." factor=$(cat $dir/../frame_subsampling_factor) || exit 1 frame_shift_opt="--frame-shift=0.0$factor" -else - align_word="$align_word lattice-align-words $reorder $lang/phones/word_boundary.int $model ark:- ark:- |" fi name=`basename $data`; # e.g. eval2000 @@ -75,7 +73,8 @@ if [ $stage -le 0 ]; then mkdir -p $dir/score_LMWT_${wip}/ '&&' \ lattice-scale --lm-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ - lattice-1best ark:- ark:- \| $align_word \ + lattice-1best ark:- ark:- \| \ + lattice-align-words $reorder_opt $lang/phones/word_boundary.int $model ark:- ark:- \| \ nbest-to-ctm $frame_shift_opt ark:- - \| \ utils/int2sym.pl -f 5 $lang/words.txt \| \ utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \ diff --git a/egs/wsj/s5/steps/nnet3/chain/train_tdnn.sh b/egs/wsj/s5/steps/nnet3/chain/train_tdnn.sh index 0f24bebdc99..0ef57020970 100755 --- a/egs/wsj/s5/steps/nnet3/chain/train_tdnn.sh +++ b/egs/wsj/s5/steps/nnet3/chain/train_tdnn.sh @@ -182,7 +182,7 @@ if [ $stage -le -6 ]; then echo "$0: creating denominator FST" copy-transition-model $treedir/final.mdl $dir/0.trans_mdl $cmd $dir/log/make_den_fst.log \ - chain-make-den-graph $dir/tree $dir/0.trans_mdl $dir/phone_lm.fst \ + chain-make-den-fst $dir/tree $dir/0.trans_mdl $dir/phone_lm.fst \ $dir/den.fst $dir/normalization.fst || exit 1; fi diff --git a/egs/wsj/s5/steps/nnet3/decode.sh b/egs/wsj/s5/steps/nnet3/decode.sh index 17133fc88de..f4de09740ae 100755 --- a/egs/wsj/s5/steps/nnet3/decode.sh +++ b/egs/wsj/s5/steps/nnet3/decode.sh @@ -128,7 +128,7 @@ fi if [ ! -z "$online_ivector_dir" ]; then ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; - ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector_period=$ivector_period" + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" fi if [ "$post_decode_acwt" == 1.0 ]; then @@ -137,9 +137,14 @@ else lat_wspecifier="ark:|lattice-scale --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz" fi +if [ -f $srcdir/frame_subsampling_factor ]; then + # e.g. for 'chain' systems + frame_subsampling_opt="--frame-subsampling-factor=$(cat $srcdir/frame_subsampling_factor)" +fi + if [ $stage -le 1 ]; then $cmd --num-threads $num_threads JOB=1:$nj $dir/log/decode.JOB.log \ - nnet3-latgen-faster$thread_string $ivector_opts \ + nnet3-latgen-faster$thread_string $ivector_opts $frame_subsampling_opt \ --frames-per-chunk=$frames_per_chunk \ --minimize=$minimize --max-active=$max_active --min-active=$min_active --beam=$beam \ --lattice-beam=$lattice_beam --acoustic-scale=$acwt --allow-partial=true \ diff --git a/src/chain/chain-den-graph.cc b/src/chain/chain-den-graph.cc index 68a7c88e888..1a8b6219a41 100644 --- a/src/chain/chain-den-graph.cc +++ b/src/chain/chain-den-graph.cc @@ -291,7 +291,6 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep, KALDI_LOG << "Number of states and arcs in phone-LM FST is " << phone_lm.NumStates() << " and " << NumArcs(phone_lm); - int32 subsequential_symbol = trans_model.GetPhones().back() + 1; if (ctx_dep.CentralPosition() != ctx_dep.ContextWidth() - 1) { // note: this function only adds the subseq symbol to the input of what was @@ -316,14 +315,14 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep, std::vector disambig_syms_h; // disambiguation symbols on input side // of H -- will be empty. - HTransducerConfig h_cfg; - h_cfg.transition_scale = 0.0; // we don't want transition probs. - h_cfg.push_weights = false; // there's nothing to push. + HTransducerConfig h_config; + h_config.transition_scale = 0.0; // we don't want transition probs. + h_config.push_weights = false; // there's nothing to push. StdVectorFst *h_fst = GetHTransducer(cfst.ILabelInfo(), ctx_dep, trans_model, - h_cfg, + h_config, &disambig_syms_h); KALDI_ASSERT(disambig_syms_h.empty()); StdVectorFst transition_id_fst; diff --git a/src/chain/chain-kernels.cu b/src/chain/chain-kernels.cu index 2330f2cc315..ca7c8faa792 100644 --- a/src/chain/chain-kernels.cu +++ b/src/chain/chain-kernels.cu @@ -40,9 +40,9 @@ __device__ inline void atomic_add_thresholded(Real* address, Real value) { // threshold itself with probability (value / threshold). This preserves // expectations. Note: we assume that value >= 0. - // you can chose any value for the threshold, but powers of 2 are nice + // you can choose any value for the threshold, but powers of 2 are nice // because they will exactly preserve the precision of the value. - const Real threshold = 1.0 / (1 << 16); + const Real threshold = 1.0 / (1 << 14); if (value >= threshold) { atomic_add(address, value); } else { diff --git a/src/chain/chain-supervision.cc b/src/chain/chain-supervision.cc index ab3b3e2f3d7..d6d2412d568 100644 --- a/src/chain/chain-supervision.cc +++ b/src/chain/chain-supervision.cc @@ -58,7 +58,7 @@ bool AlignmentToProtoSupervision(const SupervisionOptions &opts, std::vector labels(phones.size()); int32 num_frames = std::accumulate(durations.begin(), durations.end(), 0), factor = opts.frame_subsampling_factor, - num_frames_subsampled = num_frames / factor; + num_frames_subsampled = (num_frames + factor - 1) / factor; proto_supervision->allowed_phones.clear(); proto_supervision->allowed_phones.resize(num_frames_subsampled); proto_supervision->fst.DeleteStates(); @@ -69,14 +69,15 @@ bool AlignmentToProtoSupervision(const SupervisionOptions &opts, for (int32 i = 0; i < num_phones; i++) { int32 phone = phones[i], duration = durations[i]; KALDI_ASSERT(phone > 0 && duration > 0); - int32 t_start_subsampled = - std::max(0, - (current_frame - opts.left_tolerance) / factor), - t_end_subsampled = std::min( - num_frames_subsampled, - (current_frame + duration + opts.right_tolerance) / factor); + int32 t_start = std::max(0, (current_frame - opts.left_tolerance)), + t_end = std::min(num_frames, + (current_frame + duration + opts.right_tolerance)), + t_start_subsampled = (t_start + factor - 1) / factor, + t_end_subsampled = (t_end + factor - 1) / factor; + // note: if opts.Check() passed, the following assert should pass too. - KALDI_ASSERT(t_end_subsampled > t_start_subsampled); + KALDI_ASSERT(t_end_subsampled > t_start_subsampled && + t_end_subsampled <= num_frames_subsampled); for (int32 t_subsampled = t_start_subsampled; t_subsampled < t_end_subsampled; t_subsampled++) proto_supervision->allowed_phones[t_subsampled].push_back(phone); @@ -127,13 +128,7 @@ bool PhoneLatticeToProtoSupervision(const SupervisionOptions &opts, std::vector state_times; int32 num_frames = CompactLatticeStateTimes(lat, &state_times), factor = opts.frame_subsampling_factor, - num_frames_subsampled = num_frames / factor; - if (num_frames < opts.frame_subsampling_factor) { - KALDI_WARN << "Number of frames in lattice " << num_frames - << " is less than --frame-subsampling-factor=" - << opts.frame_subsampling_factor; - return false; - } + num_frames_subsampled = (num_frames + factor - 1) / factor; for (int32 state = 0; state < num_states; state++) proto_supervision->fst.AddState(); proto_supervision->fst.SetStart(lat.Start()); @@ -156,12 +151,11 @@ bool PhoneLatticeToProtoSupervision(const SupervisionOptions &opts, fst::StdArc(phone, phone, fst::TropicalWeight::One(), lat_arc.nextstate)); - int32 t_begin_subsampled = - std::max(0, - (state_time - opts.left_tolerance) / factor), - t_end_subsampled = std::min( - num_frames_subsampled, - (next_state_time + opts.right_tolerance) / factor); + int32 t_begin = std::max(0, (state_time - opts.left_tolerance)), + t_end = std::min(num_frames, + (next_state_time + opts.right_tolerance)), + t_begin_subsampled = (t_begin + factor - 1)/ factor, + t_end_subsampled = (t_end + factor - 1)/ factor; for (int32 t_subsampled = t_begin_subsampled; t_subsampled < t_end_subsampled; t_subsampled++) proto_supervision->allowed_phones[t_subsampled].push_back(phone); diff --git a/src/chain/language-model.cc b/src/chain/language-model.cc index 90c2fa8d900..61082a1c659 100644 --- a/src/chain/language-model.cc +++ b/src/chain/language-model.cc @@ -293,6 +293,7 @@ void LanguageModelEstimator::OutputToFst( int64 tot_den = std::accumulate(den_counts.begin(), den_counts.end(), 0), tot_num = 0; // for self-testing code. + double tot_logprob = 0.0; PairMapType::const_iterator iter = num_counts.begin(), end = num_counts.end(); @@ -304,6 +305,7 @@ void LanguageModelEstimator::OutputToFst( int32 den_count = den_counts[this_state]; KALDI_ASSERT(den_count >= num_count); BaseFloat prob = num_count / static_cast(den_count); + tot_logprob += num_count * log(prob); if (phone > 0) { // it's a real phone. find out where the transition is to. PairMapType::const_iterator @@ -320,9 +322,15 @@ void LanguageModelEstimator::OutputToFst( } KALDI_ASSERT(tot_num == tot_den); KALDI_LOG << "Total number of phone instances seen was " << tot_num; + BaseFloat perplexity = exp(-(tot_logprob / tot_num)); + KALDI_LOG << "Perplexity on training data is: " << perplexity; + KALDI_LOG << "Note: perplexity on unseen data will be infinity as there is " + << "no smoothing. This is by design, to reduce the number of arcs."; fst::Connect(fst); // Make sure that Connect does not delete any states. KALDI_ASSERT(fst->NumStates() == num_states); + // arc-sort. ilabel or olabel doesn't matter, it's an acceptor. + fst::ArcSort(fst, fst::ILabelCompare()); KALDI_LOG << "Created phone language model with " << num_states << " states."; } diff --git a/src/chainbin/chain-est-phone-lm.cc b/src/chainbin/chain-est-phone-lm.cc index b5936501d87..0d538064983 100644 --- a/src/chainbin/chain-est-phone-lm.cc +++ b/src/chainbin/chain-est-phone-lm.cc @@ -64,6 +64,7 @@ int main(int argc, char *argv[]) { const std::vector &phone_seq = phones_reader.Value(); lm_estimator.AddCounts(phone_seq); } + KALDI_LOG << "Estimating phone LM"; fst::StdVectorFst fst; lm_estimator.Estimate(&fst); diff --git a/src/lat/word-align-lattice.cc b/src/lat/word-align-lattice.cc index 51886e810f8..db22c0c85e4 100644 --- a/src/lat/word-align-lattice.cc +++ b/src/lat/word-align-lattice.cc @@ -28,12 +28,12 @@ class LatticeWordAligner { public: typedef CompactLatticeArc::StateId StateId; typedef CompactLatticeArc::Label Label; - + class ComputationState { /// The state of the computation in which, /// along a single path in the lattice, we work out the word /// boundaries and output aligned arcs. public: - + /// Advance the computation state by adding the symbols and weights /// from this arc. We'll put the weight on the output arc; this helps /// keep the state-space smaller. @@ -71,18 +71,18 @@ class LatticeWordAligner { bool OutputSilenceArc(const WordBoundaryInfo &info, const TransitionModel &tmodel, CompactLatticeArc *arc_out, - bool *error); + bool *error); bool OutputOnePhoneWordArc(const WordBoundaryInfo &info, const TransitionModel &tmodel, CompactLatticeArc *arc_out, - bool *error); + bool *error); bool OutputNormalWordArc(const WordBoundaryInfo &info, const TransitionModel &tmodel, CompactLatticeArc *arc_out, bool *error); - + bool IsEmpty() { return (transition_ids_.empty() && word_labels_.empty()); } - + /// FinalWeight() will return "weight" if both transition_ids /// and word_labels are empty, otherwise it will return /// Weight::Zero(). @@ -104,7 +104,7 @@ class LatticeWordAligner { const TransitionModel &tmodel, CompactLatticeArc *arc_out, bool *error); - + size_t Hash() const { VectorHasher vh; return vh(transition_ids_) + 90647 * vh(word_labels_); @@ -121,7 +121,7 @@ class LatticeWordAligner { && word_labels_ == other.word_labels_ && weight_ == other.weight_); } - + ComputationState(): weight_(LatticeWeight::One()) { } // initial state. ComputationState(const ComputationState &other): transition_ids_(other.transition_ids_), word_labels_(other.word_labels_), @@ -143,7 +143,7 @@ class LatticeWordAligner { struct TupleHash { size_t operator() (const Tuple &state) const { return state.input_state + 102763 * state.comp_state.Hash(); - // 102763 is just an arbitrary prime number + // 102763 is just an arbitrary prime number } }; struct TupleEqual { @@ -153,7 +153,7 @@ class LatticeWordAligner { && state1.comp_state == state2.comp_state); } }; - + typedef unordered_map MapType; StateId GetStateForTuple(const Tuple &tuple, bool add_to_queue) { @@ -168,17 +168,17 @@ class LatticeWordAligner { return iter->second; } } - + void ProcessFinal(Tuple tuple, StateId output_state) { // ProcessFinal is only called if the input_state has // final-prob of One(). [else it should be zero. This // is because we called CreateSuperFinal().] - + if (tuple.comp_state.IsEmpty()) { // computation state doesn't have // anything pending. std::vector empty_vec; CompactLatticeWeight cw(tuple.comp_state.FinalWeight(), empty_vec); - lat_out_->SetFinal(output_state, Plus(lat_out_->Final(output_state), cw)); + lat_out_->SetFinal(output_state, Plus(lat_out_->Final(output_state), cw)); } else { // computation state has something pending, i.e. input or // output symbols that need to be flushed out. Note: OutputArc() would @@ -197,7 +197,7 @@ class LatticeWordAligner { } } - + void ProcessQueueElement() { KALDI_ASSERT(!queue_.empty()); Tuple tuple = queue_.back().first; @@ -248,7 +248,7 @@ class LatticeWordAligner { } } } - + LatticeWordAligner(const CompactLattice &lat, const TransitionModel &tmodel, const WordBoundaryInfo &info, @@ -266,7 +266,7 @@ class LatticeWordAligner { } fst::CreateSuperFinal(&lat_); // Creates a super-final state, so the // only final-probs are One(). - + // Inside this class, we don't want to use zero for the silence // or partial-word labels, as this will interfere with the RmEpsilon // stage, where we don't want the arcs corresponding to silence or @@ -296,10 +296,10 @@ class LatticeWordAligner { syms_to_remove.push_back(info_.silence_label); if (!syms_to_remove.empty()) { RemoveSomeInputSymbols(syms_to_remove, lat_out_); - Project(lat_out_, fst::PROJECT_INPUT); + Project(lat_out_, fst::PROJECT_INPUT); } } - + bool AlignLattice() { lat_out_->DeleteStates(); if (lat_.Start() == fst::kNoStateId) { @@ -310,7 +310,7 @@ class LatticeWordAligner { Tuple initial_tuple(lat_.Start(), initial_comp_state); StateId start_state = GetStateForTuple(initial_tuple, true); // True = add this to queue. lat_out_->SetStart(start_state); - + while (!queue_.empty()) { if (max_states_ > 0 && lat_out_->NumStates() > max_states_) { KALDI_WARN << "Number of states in lattice exceeded max-states of " @@ -323,10 +323,10 @@ class LatticeWordAligner { } RemoveEpsilonsFromLattice(); - + return !error_; } - + CompactLattice lat_; const TransitionModel &tmodel_; const WordBoundaryInfo &info_in_; @@ -335,12 +335,12 @@ class LatticeWordAligner { CompactLattice *lat_out_; std::vector > queue_; - - - + + + MapType map_; // map from tuples to StateId. bool error_; - + }; bool LatticeWordAligner::ComputationState::OutputSilenceArc( @@ -355,7 +355,7 @@ bool LatticeWordAligner::ComputationState::OutputSilenceArc( size_t len = transition_ids_.size(), i; // Keep going till we reach a "final" transition-id; note, if // reorder==true, we have to go a bit further after this. - for (i = 1; i < len; i++) { + for (i = 0; i < len; i++) { int32 tid = transition_ids_[i]; int32 this_phone = tmodel.TransitionIdToPhone(tid); if (this_phone != phone && ! *error) { // error condition: should have reached final transition-id first. @@ -379,7 +379,7 @@ bool LatticeWordAligner::ComputationState::OutputSilenceArc( } // interpret i as the number of transition-ids to consume. std::vector tids_out(transition_ids_.begin(), transition_ids_.begin()+i); - + // consumed transition ids from our internal state. *arc_out = CompactLatticeArc(info.silence_label, info.silence_label, CompactLatticeWeight(weight_, tids_out), fst::kNoStateId); @@ -396,11 +396,11 @@ bool LatticeWordAligner::ComputationState::OutputOnePhoneWordArc( if (word_labels_.empty()) return false; int32 phone = tmodel.TransitionIdToPhone(transition_ids_[0]); if (info.TypeOfPhone(phone) != WordBoundaryInfo::kWordBeginAndEndPhone) - return false; + return false; // we assume the start of transition_ids_ is the start of the phone. // this is a precondition. size_t len = transition_ids_.size(), i; - for (i = 1; i < len; i++) { + for (i = 0; i < len; i++) { int32 tid = transition_ids_[i]; int32 this_phone = tmodel.TransitionIdToPhone(tid); if (this_phone != phone && ! *error) { // error condition: should have reached final transition-id first. @@ -416,17 +416,17 @@ bool LatticeWordAligner::ComputationState::OutputOnePhoneWordArc( if (info.reorder) // we have to consume the following self-loop transition-ids. while (i < len && tmodel.IsSelfLoop(transition_ids_[i])) i++; if (i == len) return false; // we don't know if it ends here... so can't output arc. - + if (tmodel.TransitionIdToPhone(transition_ids_[i-1]) != phone && ! *error) { // another check. KALDI_WARN << "Phone changed unexpectedly in lattice " "[broken lattice or mismatched model?]"; *error = true; } - + // interpret i as the number of transition-ids to consume. std::vector tids_out(transition_ids_.begin(), transition_ids_.begin()+i); - + // consumed transition ids from our internal state. int32 word = word_labels_[0]; *arc_out = CompactLatticeArc(word, word, @@ -447,7 +447,7 @@ bool LatticeWordAligner::ComputationState::OutputNormalWordArc( if (word_labels_.empty()) return false; int32 begin_phone = tmodel.TransitionIdToPhone(transition_ids_[0]); if (info.TypeOfPhone(begin_phone) != WordBoundaryInfo::kWordBeginPhone) - return false; + return false; // we assume the start of transition_ids_ is the start of the phone. // this is a precondition. size_t len = transition_ids_.size(), i; @@ -488,7 +488,7 @@ bool LatticeWordAligner::ComputationState::OutputNormalWordArc( // a "final-transition". // this variable just used for checks. - int32 final_phone = tmodel.TransitionIdToPhone(transition_ids_[i]); + int32 final_phone = tmodel.TransitionIdToPhone(transition_ids_[i]); for (; i < len; i++) { int32 this_phone = tmodel.TransitionIdToPhone(transition_ids_[i]); if (this_phone != final_phone && ! *error) { @@ -515,7 +515,7 @@ bool LatticeWordAligner::ComputationState::OutputNormalWordArc( // OK, we're ready to output the word. // Interpret i as the number of transition-ids to consume. std::vector tids_out(transition_ids_.begin(), transition_ids_.begin()+i); - + // consumed transition ids from our internal state. int32 word = word_labels_[0]; *arc_out = CompactLatticeArc(word, word, @@ -550,7 +550,7 @@ static bool IsPlausibleWord(const WordBoundaryInfo &info, } else return false; } - + void LatticeWordAligner::ComputationState::OutputArcForce( const WordBoundaryInfo &info, const TransitionModel &tmodel, CompactLatticeArc *arc_out, bool *error) { @@ -560,7 +560,7 @@ void LatticeWordAligner::ComputationState::OutputArcForce( && !transition_ids_.empty()) { // We have at least one word to // output, and some transition-ids. We assume that the normal OutputArc was called // and failed, so this means we didn't see the end of that - // word. + // word. int32 word = word_labels_[0]; if (! *error && !IsPlausibleWord(info, tmodel, transition_ids_)) { *error = true; @@ -686,7 +686,7 @@ WordBoundaryInfo::WordBoundaryInfo(const WordBoundaryInfoNewOpts &opts, void WordBoundaryInfo::Init(std::istream &stream) { std::string line; while (std::getline(stream, line)) { - std::vector split_line; + std::vector split_line; SplitStringToVector(line, " \t\r", true, &split_line);// split the line by space or tab int32 p = 0; if (split_line.size() != 2 || @@ -701,13 +701,13 @@ void WordBoundaryInfo::Init(std::istream &stream) { else if (t == "singleton") phone_to_type[p] = kWordBeginAndEndPhone; else if (t == "end") phone_to_type[p] = kWordEndPhone; else if (t == "internal") phone_to_type[p] = kWordInternalPhone; - else + else KALDI_ERR << "Invalid line in word-boundary file: " << line; } if (phone_to_type.empty()) KALDI_ERR << "Empty word-boundary file"; } - + bool WordAlignLattice(const CompactLattice &lat, const TransitionModel &tmodel, const WordBoundaryInfo &info, @@ -726,7 +726,7 @@ class WordAlignedLatticeTester { const WordBoundaryInfo &info, const CompactLattice &aligned_lat): lat_(lat), tmodel_(tmodel), info_(info), aligned_lat_(aligned_lat) { } - + void Test() { // First test that each aligned arc is valid. typedef CompactLattice::StateId StateId ; @@ -766,7 +766,7 @@ class WordAlignedLatticeTester { return false; for (size_t i = 0; i < tids.size(); i++) if (tmodel_.TransitionIdToPhone(tids[i]) != first_phone) return false; - + if (!info_.reorder) return tmodel_.IsFinal(tids.back()); else { for (size_t i = 0; i < tids.size(); i++) { @@ -794,7 +794,7 @@ class WordAlignedLatticeTester { WordBoundaryInfo::kWordBeginAndEndPhone) return false; for (size_t i = 0; i < tids.size(); i++) if (tmodel_.TransitionIdToPhone(tids[i]) != first_phone) return false; - + if (!info_.reorder) return tmodel_.IsFinal(tids.back()); else { for (size_t i = 0; i < tids.size(); i++) { @@ -871,7 +871,7 @@ class WordAlignedLatticeTester { if (tids.empty()) return false; return true; // We're pretty liberal when it comes to partial words here. } - + void TestFinal(const CompactLatticeWeight &w) { if (!w.String().empty()) KALDI_ERR << "Expect to have no strings on final-weights of lattices."; @@ -890,14 +890,14 @@ class WordAlignedLatticeTester { KALDI_ERR << "Equivalence test failed (testing word-alignment of lattices.) " << "Make sure your model and lattices match!"; } - + const CompactLattice &lat_; const TransitionModel &tmodel_; const WordBoundaryInfo &info_; const CompactLattice &aligned_lat_; }; - - + + /// You should only test a lattice if WordAlignLattice returned true (i.e. it diff --git a/src/nnet3/nnet-am-decodable-simple.cc b/src/nnet3/nnet-am-decodable-simple.cc index ae1ada946c8..00dcde2047c 100644 --- a/src/nnet3/nnet-am-decodable-simple.cc +++ b/src/nnet3/nnet-am-decodable-simple.cc @@ -40,14 +40,17 @@ NnetDecodableBase::NnetDecodableBase( ivector_(ivector), online_ivector_feats_(online_ivectors), online_ivector_period_(online_ivector_period), compiler_(nnet_, opts_.optimize_config), - current_log_post_offset_(0) { + current_log_post_subsampled_offset_(0) { + num_subsampled_frames_ = + (feats_.NumRows() + opts_.frame_subsampling_factor - 1) / + opts_.frame_subsampling_factor; KALDI_ASSERT(IsSimpleNnet(nnet)); ComputeSimpleNnetContext(nnet, &nnet_left_context_, &nnet_right_context_); KALDI_ASSERT(!(ivector != NULL && online_ivectors != NULL)); KALDI_ASSERT(!(online_ivectors != NULL && online_ivector_period <= 0 && "You need to set the --online-ivector-period option!")); log_priors_.ApplyLog(); - PossiblyWarnForFramesPerChunk(); + CheckAndFixConfigs(); } @@ -83,9 +86,9 @@ int32 NnetDecodableBase::GetIvectorDim() const { return 0; } -void NnetDecodableBase::EnsureFrameIsComputed(int32 frame) { - KALDI_ASSERT(frame >= 0 && frame < feats_.NumRows()); - +void NnetDecodableBase::EnsureFrameIsComputed(int32 subsampled_frame) { + KALDI_ASSERT(subsampled_frame >= 0 && + subsampled_frame < num_subsampled_frames_); int32 feature_dim = feats_.NumCols(), ivector_dim = GetIvectorDim(), nnet_input_dim = nnet_.InputDim("input"), @@ -98,30 +101,44 @@ void NnetDecodableBase::EnsureFrameIsComputed(int32 frame) { KALDI_ERR << "Neural net expects 'ivector' features with dimension " << nnet_ivector_dim << " but you provided " << ivector_dim; - int32 current_frames_computed = current_log_post_.NumRows(), - current_offset = current_log_post_offset_; - KALDI_ASSERT(frame < current_offset || - frame >= current_offset + current_frames_computed); - // allow the output to be computed for frame 0 ... num_input_frames - 1. - int32 start_output_frame = frame, - num_output_frames = std::min(feats_.NumRows() - start_output_frame, - opts_.frames_per_chunk); - KALDI_ASSERT(num_output_frames > 0); + int32 current_subsampled_frames_computed = current_log_post_.NumRows(), + current_subsampled_offset = current_log_post_subsampled_offset_; + KALDI_ASSERT(subsampled_frame < current_subsampled_offset || + subsampled_frame >= current_subsampled_offset + + current_subsampled_frames_computed); + + // all subsampled frames pertain to the output of the network, + // they are output frames divided by opts_.frame_subsampling_factor. + int32 subsampling_factor = opts_.frame_subsampling_factor, + subsampled_frames_per_chunk = opts_.frames_per_chunk / subsampling_factor, + start_subsampled_frame = subsampled_frame, + num_subsampled_frames = std::min(num_subsampled_frames_ - + start_subsampled_frame, + subsampled_frames_per_chunk), + last_subsampled_frame = start_subsampled_frame + num_subsampled_frames - 1; + KALDI_ASSERT(num_subsampled_frames > 0); + // the output-frame numbers are the subsampled-frame numbers + int32 first_output_frame = start_subsampled_frame * subsampling_factor, + last_output_frame = last_subsampled_frame * subsampling_factor; + KALDI_ASSERT(opts_.extra_left_context >= 0); int32 left_context = nnet_left_context_ + opts_.extra_left_context; - int32 first_input_frame = start_output_frame - left_context, - num_input_frames = nnet_left_context_ + num_output_frames + - nnet_right_context_; + int32 first_input_frame = first_output_frame - left_context, + last_input_frame = last_output_frame + nnet_right_context_, + num_input_frames = last_input_frame + 1 - first_input_frame; + Vector ivector; - GetCurrentIvector(start_output_frame, num_output_frames, &ivector); + GetCurrentIvector(first_output_frame, + last_output_frame - first_output_frame, + &ivector); Matrix input_feats; if (first_input_frame >= 0 && - first_input_frame + num_input_frames <= feats_.NumRows()) { + last_input_frame < feats_.NumRows()) { SubMatrix input_feats(feats_.RowRange(first_input_frame, num_input_frames)); DoNnetComputation(first_input_frame, input_feats, ivector, - start_output_frame, num_output_frames); + first_output_frame, num_subsampled_frames); } else { Matrix feats_block(num_input_frames, feats_.NumCols()); int32 tot_input_feats = feats_.NumRows(); @@ -134,21 +151,25 @@ void NnetDecodableBase::EnsureFrameIsComputed(int32 frame) { dest.CopyFromVec(src); } DoNnetComputation(first_input_frame, feats_block, ivector, - start_output_frame, num_output_frames); + first_output_frame, num_subsampled_frames); } } -void NnetDecodableBase::GetOutputForFrame(int32 frame, +// note: in the normal case (with no frame subsampling) you can ignore the +// 'subsampled_' in the variable name. +void NnetDecodableBase::GetOutputForFrame(int32 subsampled_frame, VectorBase *output) { - if (frame < current_log_post_offset_ || - frame >= current_log_post_offset_ + current_log_post_.NumRows()) - EnsureFrameIsComputed(frame); - output->CopyFromVec(current_log_post_.Row(frame - current_log_post_offset_)); + if (subsampled_frame < current_log_post_subsampled_offset_ || + subsampled_frame >= current_log_post_subsampled_offset_ + + current_log_post_.NumRows()) + EnsureFrameIsComputed(subsampled_frame); + output->CopyFromVec(current_log_post_.Row( + subsampled_frame - current_log_post_subsampled_offset_)); } void NnetDecodableBase::GetCurrentIvector(int32 output_t_start, - int32 num_output_frames, - Vector *ivector) { + int32 num_output_frames, + Vector *ivector) { if (ivector_ != NULL) { *ivector = *ivector_; return; @@ -185,7 +206,7 @@ void NnetDecodableBase::DoNnetComputation( const MatrixBase &input_feats, const VectorBase &ivector, int32 output_t_start, - int32 num_output_frames) { + int32 num_subsampled_frames) { ComputationRequest request; request.need_model_derivative = false; request.store_component_stats = false; @@ -205,9 +226,17 @@ void NnetDecodableBase::DoNnetComputation( indexes.push_back(Index(0, 0, 0)); request.inputs.push_back(IoSpecification("ivector", indexes)); } - request.outputs.push_back( - IoSpecification("output", time_offset + output_t_start, - time_offset + output_t_start + num_output_frames)); + IoSpecification output_spec; + output_spec.name = "output"; + output_spec.has_deriv = false; + int32 subsample = opts_.frame_subsampling_factor; + output_spec.indexes.resize(num_subsampled_frames); + // leave n and x values at 0 (the constructor sets these). + for (int32 i = 0; i < num_subsampled_frames; i++) + output_spec.indexes[i].t = time_offset + output_t_start + i * subsample; + request.outputs.resize(1); + request.outputs[0].Swap(&output_spec); + const NnetComputation *computation = compiler_.Compile(request); Nnet *nnet_to_update = NULL; // we're not doing any update. NnetComputer computer(opts_.compute_config, *computation, @@ -232,14 +261,31 @@ void NnetDecodableBase::DoNnetComputation( current_log_post_.Resize(0, 0); // the following statement just swaps the pointers if we're not using a GPU. cu_output.Swap(¤t_log_post_); - current_log_post_offset_ = output_t_start; + current_log_post_subsampled_offset_ = output_t_start / subsample; } -void NnetDecodableBase::PossiblyWarnForFramesPerChunk() const { - static bool warned = false; +void NnetDecodableBase::CheckAndFixConfigs() { + static bool warned_modulus = false, + warned_subsampling = false; int32 nnet_modulus = nnet_.Modulus(); - if (opts_.frames_per_chunk % nnet_modulus != 0 && !warned) { - warned = true; + if (opts_.frame_subsampling_factor < 1 || + opts_.frames_per_chunk < 1) + KALDI_ERR << "--frame-subsampling-factor and --frames-per-chunk must be > 0"; + if (opts_.frames_per_chunk % opts_.frame_subsampling_factor != 0) { + int32 f = opts_.frame_subsampling_factor, + frames_per_chunk = f * ((opts_.frames_per_chunk + f - 1) / f); + if (!warned_subsampling) { + warned_subsampling = true; + KALDI_LOG << "Increasing --frames-per-chunk from " + << opts_.frames_per_chunk << " to " + << frames_per_chunk << " to make it a multiple of " + << "--frame-subsampling-factor=" + << opts_.frame_subsampling_factor; + } + opts_.frames_per_chunk = frames_per_chunk; + } + if (opts_.frames_per_chunk % nnet_modulus != 0 && !warned_modulus) { + warned_modulus = true; KALDI_WARN << "It may be more efficient to set the --frames-per-chunk " << "(currently " << opts_.frames_per_chunk << " to a " << "multiple of the network's shift-invariance modulus " @@ -249,4 +295,4 @@ void NnetDecodableBase::PossiblyWarnForFramesPerChunk() const { } // namespace nnet3 } // namespace kaldi - + diff --git a/src/nnet3/nnet-am-decodable-simple.h b/src/nnet3/nnet-am-decodable-simple.h index 4eeb262c787..15399b1308d 100644 --- a/src/nnet3/nnet-am-decodable-simple.h +++ b/src/nnet3/nnet-am-decodable-simple.h @@ -37,6 +37,7 @@ namespace nnet3 { // for which IsSimpleNnet(nnet) would return true. struct NnetSimpleComputationOptions { int32 extra_left_context; + int32 frame_subsampling_factor; int32 frames_per_chunk; BaseFloat acoustic_scale; bool debug_computation; @@ -45,6 +46,7 @@ struct NnetSimpleComputationOptions { NnetSimpleComputationOptions(): extra_left_context(0), + frame_subsampling_factor(1), frames_per_chunk(50), acoustic_scale(0.1), debug_computation(false) { } @@ -54,11 +56,17 @@ struct NnetSimpleComputationOptions { "Number of frames of additional left-context to add on top " "of the neural net's inherent left context (may be useful in " "recurrent setups"); + opts->Register("frame-subsampling-factor", &frame_subsampling_factor, + "Required if the frame-rate of the output (e.g. in 'chain' " + "models) is less than the frame-rate of the original " + "alignment."); opts->Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic log-likelihoods"); opts->Register("frames-per-chunk", &frames_per_chunk, "Number of frames in each chunk that is separately evaluated " - "by the neural net."); + "by the neural net. Measured before any subsampling, if the " + "--frame-subsampling-factor options is used (i.e. counts " + "input frames"); opts->Register("debug-computation", &debug_computation, "If true, turn on " "debug for the actual computation (very verbose!)"); @@ -115,8 +123,10 @@ class NnetDecodableBase { int32 online_ivector_period = 1); - // returns the number of frames of likelihoods. - inline int32 NumFrames() const { return feats_.NumRows(); } + // returns the number of frames of likelihoods. The same as feats_.NumRows() + // in the normal case (but may be less if opts_.frame_subsampling_factor != + // 1). + inline int32 NumFrames() const { return num_subsampled_frames_; } inline int32 OutputDim() const { return output_dim_; } @@ -124,19 +134,22 @@ class NnetDecodableBase { // 'output' must be correctly sized (with dimension OutputDim()). void GetOutputForFrame(int32 frame, VectorBase *output); - // Gets the output for a particular frame and pdf_id, with 0 <= frame < NumFrames(), + // Gets the output for a particular frame and pdf_id, with + // 0 <= subsampled_frame < NumFrames(), // and 0 <= pdf_id < OutputDim(). - inline BaseFloat GetOutput(int32 frame, int32 pdf_id) { - if (frame < current_log_post_offset_ || - frame >= current_log_post_offset_ + current_log_post_.NumRows()) - EnsureFrameIsComputed(frame); - return current_log_post_(frame - current_log_post_offset_, + inline BaseFloat GetOutput(int32 subsampled_frame, int32 pdf_id) { + if (subsampled_frame < current_log_post_subsampled_offset_ || + subsampled_frame >= current_log_post_subsampled_offset_ + + current_log_post_.NumRows()) + EnsureFrameIsComputed(subsampled_frame); + return current_log_post_(subsampled_frame - + current_log_post_subsampled_offset_, pdf_id); } private: // This call is made to ensure that we have the log-probs for this frame // cached in current_log_post_. - void EnsureFrameIsComputed(int32 frame); + void EnsureFrameIsComputed(int32 subsampled_frame); // This function does the actual nnet computation; it is called from // EnsureFrameIsComputed. Any padding at file start/end is done by @@ -146,19 +159,24 @@ class NnetDecodableBase { const MatrixBase &input_feats, const VectorBase &ivector, int32 output_t_start, - int32 num_output_frames); - - // Gets the iVector that will be used for this chunk of frames, if - // we are using iVectors (else does nothing). - void GetCurrentIvector(int32 output_t_start, int32 num_output_frames, + int32 num_subsampled_frames); + + // Gets the iVector that will be used for this chunk of frames, if we are + // using iVectors (else does nothing). note: the num_output_frames is + // interpreted as the number of t value, which in the subsampled case is not + // the same as the number of subsampled frames (it would be larger by + // opts_.frame_subsampling_factor). + void GetCurrentIvector(int32 output_t_start, + int32 num_output_frames, Vector *ivector); - void PossiblyWarnForFramesPerChunk() const; + // called from constructor + void CheckAndFixConfigs(); // returns dimension of the provided iVectors if supplied, or 0 otherwise. int32 GetIvectorDim() const; - const NnetSimpleComputationOptions &opts_; + NnetSimpleComputationOptions opts_; const Nnet &nnet_; int32 nnet_left_context_; int32 nnet_right_context_; @@ -166,6 +184,9 @@ class NnetDecodableBase { // the log priors (or the empty vector if the priors are not set in the model) CuVector log_priors_; const MatrixBase &feats_; + // note: num_subsampled_frames_ will equal feats_.NumRows() in the normal case + // when opts_.frame_subsampling_factor == 1. + int32 num_subsampled_frames_; // ivector_ is the iVector if we're using iVectors that are estimated in batch // mode. @@ -182,8 +203,10 @@ class NnetDecodableBase { // The current log-posteriors that we got from the last time we // ran the computation. Matrix current_log_post_; - // The time-offset of the current log-posteriors. - int32 current_log_post_offset_; + // The time-offset of the current log-posteriors. Note: if + // opts_.frame_subsampling_factor > 1, this will be measured in subsampled + // frames. + int32 current_log_post_subsampled_offset_; }; diff --git a/src/nnet3/nnet-cctc-decodable-simple.cc b/src/nnet3/nnet-cctc-decodable-simple.cc index ab09879c9c3..bb61f33e6b0 100644 --- a/src/nnet3/nnet-cctc-decodable-simple.cc +++ b/src/nnet3/nnet-cctc-decodable-simple.cc @@ -145,7 +145,7 @@ void DecodableNnetCctcSimple::EnsureFrameIsComputed(int32 subsampled_frame) { current_subsampled_frames_computed); // all subsampled frames pertain to the output of the network, - // they are output frames divided by opts_.frame_subsampled_factor. + // they are output frames divided by opts_.frame_subsampling_factor. int32 subsampling_factor = opts_.frame_subsampling_factor, subsampled_frames_per_chunk = opts_.frames_per_chunk / subsampling_factor, start_subsampled_frame = subsampled_frame, diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc index b9b1b5ad282..eab486ade85 100644 --- a/src/nnet3/nnet-simple-component.cc +++ b/src/nnet3/nnet-simple-component.cc @@ -261,7 +261,7 @@ void NormalizeComponent::Write(std::ostream &os, bool binary) const { std::string NormalizeComponent::Info() const { std::stringstream stream; stream << NonlinearComponent::Info(); - stream << ", target_rms=" << target_rms_; + stream << ", target-rms=" << target_rms_; return stream.str(); }