Skip to content

Commit

Permalink
chain branch: various changes to get decoding working; various mostly…
Browse files (browse the repository at this point in the history)
… minor fixes; change threshold in randomized pruning for beta.
  • Loading branch information
danpovey committed Nov 22, 2015
1 parent d7bd924 commit 0398ac2
Show file tree
Hide file tree
Showing 14 changed files with 222 additions and 146 deletions.
7 changes: 4 additions & 3 deletions egs/swbd/s5c/local/chain/run_tdnn_a.sh
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ fi

if [ $stage -le 13 ]; then
# Note: it might appear that this $lang directory is mismatched, and it is as
# far as the 'topo'
# far as the 'topo' is concerned, but this script doesn't read the 'topo' from
# the lang directory.
utils/mkgraph.sh --transition-scale 0.0 \
--self-loop-scale 0.0 data/lang_sw1_tg $dir $dir/graph_sw1_tg
fi
Expand All @@ -129,7 +130,7 @@ graph_dir=$dir/graph_sw1_tg
if [ $stage -le 14 ]; then
for decode_set in train_dev eval2000; do
(
steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --iter 298_cached \
steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \
--nj 50 --cmd "$decode_cmd" \
--online-ivector-dir exp/nnet3/ivectors_${decode_set} \
$graph_dir data/${decode_set}_hires $dir/decode_${decode_set}_${decode_suff} || exit 1;
Expand All @@ -138,7 +139,7 @@ if [ $stage -le 14 ]; then
data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \
$dir/decode_${decode_set}_sw1_{tg,fsh_fg} || exit 1;
fi
) # &
) &
done
fi
wait;
Expand Down
11 changes: 5 additions & 6 deletions egs/swbd/s5c/local/score_sclite.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,16 @@ nnet3-ctc-info --print-args=false $model 1>/dev/null 2>&1;
[ $? -eq 0 ] && is_ctc=true;
[ -z $is_ctc ] && echo "Unknown model type, verify if $model exists" && exit -1;
align_word=
reorder=
reorder_opt=
if $reverse; then
align_word="lattice-reverse ark:- ark:- |"
reorder="--reorder=false"
reorder_opt="--reorder=false"
fi

if $is_ctc ; then
echo "Warning : This is a CTC model, using corresponding scoring pipeline."
echo "Warning : This is a 'chain' model, using corresponding scoring pipeline."
factor=$(cat $dir/../frame_subsampling_factor) || exit 1
frame_shift_opt="--frame-shift=0.0$factor"
else
align_word="$align_word lattice-align-words $reorder $lang/phones/word_boundary.int $model ark:- ark:- |"
fi

name=`basename $data`; # e.g. eval2000
Expand All @@ -75,7 +73,8 @@ if [ $stage -le 0 ]; then
mkdir -p $dir/score_LMWT_${wip}/ '&&' \
lattice-scale --lm-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \
lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \
lattice-1best ark:- ark:- \| $align_word \
lattice-1best ark:- ark:- \| \
lattice-align-words $reorder_opt $lang/phones/word_boundary.int $model ark:- ark:- \| \
nbest-to-ctm $frame_shift_opt ark:- - \| \
utils/int2sym.pl -f 5 $lang/words.txt \| \
utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \
Expand Down
2 changes: 1 addition & 1 deletion egs/wsj/s5/steps/nnet3/chain/train_tdnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ if [ $stage -le -6 ]; then
echo "$0: creating denominator FST"
copy-transition-model $treedir/final.mdl $dir/0.trans_mdl
$cmd $dir/log/make_den_fst.log \
chain-make-den-graph $dir/tree $dir/0.trans_mdl $dir/phone_lm.fst \
chain-make-den-fst $dir/tree $dir/0.trans_mdl $dir/phone_lm.fst \
$dir/den.fst $dir/normalization.fst || exit 1;
fi

Expand Down
9 changes: 7 additions & 2 deletions egs/wsj/s5/steps/nnet3/decode.sh
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fi

if [ ! -z "$online_ivector_dir" ]; then
ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1;
ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector_period=$ivector_period"
ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period"
fi

if [ "$post_decode_acwt" == 1.0 ]; then
Expand All @@ -137,9 +137,14 @@ else
lat_wspecifier="ark:|lattice-scale --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz"
fi

if [ -f $srcdir/frame_subsampling_factor ]; then
# e.g. for 'chain' systems
frame_subsampling_opt="--frame-subsampling-factor=$(cat $srcdir/frame_subsampling_factor)"
fi

if [ $stage -le 1 ]; then
$cmd --num-threads $num_threads JOB=1:$nj $dir/log/decode.JOB.log \
nnet3-latgen-faster$thread_string $ivector_opts \
nnet3-latgen-faster$thread_string $ivector_opts $frame_subsampling_opt \
--frames-per-chunk=$frames_per_chunk \
--minimize=$minimize --max-active=$max_active --min-active=$min_active --beam=$beam \
--lattice-beam=$lattice_beam --acoustic-scale=$acwt --allow-partial=true \
Expand Down
9 changes: 4 additions & 5 deletions src/chain/chain-den-graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep,
KALDI_LOG << "Number of states and arcs in phone-LM FST is "
<< phone_lm.NumStates() << " and " << NumArcs(phone_lm);


int32 subsequential_symbol = trans_model.GetPhones().back() + 1;
if (ctx_dep.CentralPosition() != ctx_dep.ContextWidth() - 1) {
// note: this function only adds the subseq symbol to the input of what was
Expand All @@ -316,14 +315,14 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep,

std::vector<int32> disambig_syms_h; // disambiguation symbols on input side
// of H -- will be empty.
HTransducerConfig h_cfg;
h_cfg.transition_scale = 0.0; // we don't want transition probs.
h_cfg.push_weights = false; // there's nothing to push.
HTransducerConfig h_config;
h_config.transition_scale = 0.0; // we don't want transition probs.
h_config.push_weights = false; // there's nothing to push.

StdVectorFst *h_fst = GetHTransducer(cfst.ILabelInfo(),
ctx_dep,
trans_model,
h_cfg,
h_config,
&disambig_syms_h);
KALDI_ASSERT(disambig_syms_h.empty());
StdVectorFst transition_id_fst;
Expand Down
4 changes: 2 additions & 2 deletions src/chain/chain-kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ __device__ inline void atomic_add_thresholded(Real* address, Real value) {
// threshold itself with probability (value / threshold). This preserves
// expectations. Note: we assume that value >= 0.

// you can chose any value for the threshold, but powers of 2 are nice
// you can choose any value for the threshold, but powers of 2 are nice
// because they will exactly preserve the precision of the value.
const Real threshold = 1.0 / (1 << 16);
const Real threshold = 1.0 / (1 << 14);
if (value >= threshold) {
atomic_add(address, value);
} else {
Expand Down
36 changes: 15 additions & 21 deletions src/chain/chain-supervision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ bool AlignmentToProtoSupervision(const SupervisionOptions &opts,
std::vector<int32> labels(phones.size());
int32 num_frames = std::accumulate(durations.begin(), durations.end(), 0),
factor = opts.frame_subsampling_factor,
num_frames_subsampled = num_frames / factor;
num_frames_subsampled = (num_frames + factor - 1) / factor;
proto_supervision->allowed_phones.clear();
proto_supervision->allowed_phones.resize(num_frames_subsampled);
proto_supervision->fst.DeleteStates();
Expand All @@ -69,14 +69,15 @@ bool AlignmentToProtoSupervision(const SupervisionOptions &opts,
for (int32 i = 0; i < num_phones; i++) {
int32 phone = phones[i], duration = durations[i];
KALDI_ASSERT(phone > 0 && duration > 0);
int32 t_start_subsampled =
std::max<int32>(0,
(current_frame - opts.left_tolerance) / factor),
t_end_subsampled = std::min<int32>(
num_frames_subsampled,
(current_frame + duration + opts.right_tolerance) / factor);
int32 t_start = std::max<int32>(0, (current_frame - opts.left_tolerance)),
t_end = std::min<int32>(num_frames,
(current_frame + duration + opts.right_tolerance)),
t_start_subsampled = (t_start + factor - 1) / factor,
t_end_subsampled = (t_end + factor - 1) / factor;

// note: if opts.Check() passed, the following assert should pass too.
KALDI_ASSERT(t_end_subsampled > t_start_subsampled);
KALDI_ASSERT(t_end_subsampled > t_start_subsampled &&
t_end_subsampled <= num_frames_subsampled);
for (int32 t_subsampled = t_start_subsampled;
t_subsampled < t_end_subsampled; t_subsampled++)
proto_supervision->allowed_phones[t_subsampled].push_back(phone);
Expand Down Expand Up @@ -127,13 +128,7 @@ bool PhoneLatticeToProtoSupervision(const SupervisionOptions &opts,
std::vector<int32> state_times;
int32 num_frames = CompactLatticeStateTimes(lat, &state_times),
factor = opts.frame_subsampling_factor,
num_frames_subsampled = num_frames / factor;
if (num_frames < opts.frame_subsampling_factor) {
KALDI_WARN << "Number of frames in lattice " << num_frames
<< " is less than --frame-subsampling-factor="
<< opts.frame_subsampling_factor;
return false;
}
num_frames_subsampled = (num_frames + factor - 1) / factor;
for (int32 state = 0; state < num_states; state++)
proto_supervision->fst.AddState();
proto_supervision->fst.SetStart(lat.Start());
Expand All @@ -156,12 +151,11 @@ bool PhoneLatticeToProtoSupervision(const SupervisionOptions &opts,
fst::StdArc(phone, phone,
fst::TropicalWeight::One(),
lat_arc.nextstate));
int32 t_begin_subsampled =
std::max<int32>(0,
(state_time - opts.left_tolerance) / factor),
t_end_subsampled = std::min<int32>(
num_frames_subsampled,
(next_state_time + opts.right_tolerance) / factor);
int32 t_begin = std::max<int32>(0, (state_time - opts.left_tolerance)),
t_end = std::min<int32>(num_frames,
(next_state_time + opts.right_tolerance)),
t_begin_subsampled = (t_begin + factor - 1)/ factor,
t_end_subsampled = (t_end + factor - 1)/ factor;
for (int32 t_subsampled = t_begin_subsampled;
t_subsampled < t_end_subsampled; t_subsampled++)
proto_supervision->allowed_phones[t_subsampled].push_back(phone);
Expand Down
8 changes: 8 additions & 0 deletions src/chain/language-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ void LanguageModelEstimator::OutputToFst(
int64 tot_den = std::accumulate(den_counts.begin(),
den_counts.end(), 0),
tot_num = 0; // for self-testing code.
double tot_logprob = 0.0;

PairMapType::const_iterator
iter = num_counts.begin(), end = num_counts.end();
Expand All @@ -304,6 +305,7 @@ void LanguageModelEstimator::OutputToFst(
int32 den_count = den_counts[this_state];
KALDI_ASSERT(den_count >= num_count);
BaseFloat prob = num_count / static_cast<BaseFloat>(den_count);
tot_logprob += num_count * log(prob);
if (phone > 0) {
// it's a real phone. find out where the transition is to.
PairMapType::const_iterator
Expand All @@ -320,9 +322,15 @@ void LanguageModelEstimator::OutputToFst(
}
KALDI_ASSERT(tot_num == tot_den);
KALDI_LOG << "Total number of phone instances seen was " << tot_num;
BaseFloat perplexity = exp(-(tot_logprob / tot_num));
KALDI_LOG << "Perplexity on training data is: " << perplexity;
KALDI_LOG << "Note: perplexity on unseen data will be infinity as there is "
<< "no smoothing. This is by design, to reduce the number of arcs.";
fst::Connect(fst);
// Make sure that Connect does not delete any states.
KALDI_ASSERT(fst->NumStates() == num_states);
// arc-sort. ilabel or olabel doesn't matter, it's an acceptor.
fst::ArcSort(fst, fst::ILabelCompare<fst::StdArc>());
KALDI_LOG << "Created phone language model with " << num_states << " states.";
}

Expand Down
1 change: 1 addition & 0 deletions src/chainbin/chain-est-phone-lm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ int main(int argc, char *argv[]) {
const std::vector<int32> &phone_seq = phones_reader.Value();
lm_estimator.AddCounts(phone_seq);
}
KALDI_LOG << "Estimating phone LM";
fst::StdVectorFst fst;
lm_estimator.Estimate(&fst);

Expand Down
Loading

0 comments on commit 0398ac2

Please sign in to comment.