forked from kaldi-asr/kaldi
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[scripts,egs] Support averaging forward and backward RNNLMs (kaldi-as…
- Loading branch information
Showing
6 changed files
with
355 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
tuning/run_tdnn_lstm_back_1e.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
142 changes: 142 additions & 0 deletions
142
egs/swbd/s5c/local/rnnlm/tuning/run_tdnn_lstm_back_1e.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
#!/bin/bash

# Copyright 2012  Johns Hopkins University (author: Daniel Povey)
#           2015  Guoguo Chen
#           2017  Hainan Xu
#           2017  Xiaohui Zhang

# This script trains a backward LM on the swbd LM-training data, and uses it
# to rescore either decoded lattices, or lattices that are just rescored with
# a forward RNNLM. In order to run this, you must first run the forward RNNLM
# recipe at local/rnnlm/run_tdnn_lstm.sh

# rnnlm/train_rnnlm.sh: best iteration (out of 35) was 34, linking it to final iteration.
# rnnlm/train_rnnlm.sh: train/dev perplexity was 41.8 / 55.1.
# Train objf: -5.18 -4.46 -4.26 -4.18 -4.12 -4.07 -4.04 -4.00 -3.99 -3.98 -3.95 -3.93 -3.91 -3.90 -3.88 -3.87 -3.86 -3.85 -3.83 -3.82 -3.82 -3.81 -3.79 -3.79 -3.78 -3.77 -3.76 -3.77 -3.75 -3.74 -3.74 -3.73 -3.72 -3.71 -3.71
# Dev objf:   -10.32 -4.89 -4.57 -4.45 -4.37 -4.33 -4.29 -4.26 -4.24 -4.22 -4.18 -4.17 -4.15 -4.14 -4.13 -4.12 -4.11 -4.10 -4.09 -4.08 -4.07 -4.06 -4.06 -4.05 -4.05 -4.05 -4.04 -4.04 -4.03 -4.03 -4.02 -4.02 -4.02 -4.01 -4.01

# %WER 11.1 | 1831 21395 | 89.9 6.4 3.7 1.0 11.1 46.3 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped/score_13_0.0/eval2000_hires.ctm.swbd.filt.sys
# %WER 9.9 | 1831 21395 | 91.0 5.8 3.2 0.9 9.9 43.2 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped_rnnlm_1e/score_11_0.0/eval2000_hires.ctm.swbd.filt.sys
# %WER 9.5 | 1831 21395 | 91.4 5.5 3.1 0.9 9.5 42.5 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped_rnnlm_1e_back/score_11_0.0/eval2000_hires.ctm.swbd.filt.sys

# %WER 15.9 | 4459 42989 | 85.7 9.7 4.6 1.6 15.9 51.6 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped/score_10_0.0/eval2000_hires.ctm.filt.sys
# %WER 14.4 | 4459 42989 | 87.0 8.7 4.3 1.5 14.4 49.4 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped_rnnlm_1e/score_11_0.0/eval2000_hires.ctm.filt.sys
# %WER 13.9 | 4459 42989 | 87.6 8.4 4.0 1.5 13.9 48.6 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped_rnnlm_1e_back/score_10_0.0/eval2000_hires.ctm.filt.sys

# Begin configuration section.

dir=exp/rnnlm_lstm_1e_backward
embedding_dim=1024
lstm_rpd=256
lstm_nrpd=256
stage=-10
train_stage=-10

# variables for lattice rescoring
run_lat_rescore=true
ac_model_dir=exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp
decode_dir_suffix_forward=rnnlm_1e
decode_dir_suffix_backward=rnnlm_1e_back
ngram_order=4 # approximate the lattice-rescoring by limiting the max-ngram-order
              # if it's set, it merges histories in the lattice if they share
              # the same ngram history and this prevents the lattice from
              # exploding exponentially

. ./cmd.sh
. ./utils/parse_options.sh

text=data/train_nodev/text
fisher_text=data/local/lm/fisher/text1.gz
lexicon=data/local/dict_nosp/lexiconp.txt
text_dir=data/rnnlm/text_nosp_1e_back
mkdir -p $dir/config
set -e

for f in $text $lexicon; do
  [ ! -f $f ] && \
    echo "$0: expected file $f to exist; search for local/wsj_extend_dict.sh in run.sh" && exit 1
done

if [ $stage -le 0 ]; then
  mkdir -p $text_dir
  echo -n >$text_dir/dev.txt
  # hold out one in every 50 lines as dev data; the first awk reverses the word
  # order of each line, since we are training a backward (right-to-left) LM.
  cat $text | cut -d ' ' -f2- | awk '{for(i=NF;i>0;i--) printf("%s ", $i); print""}' | awk -v text_dir=$text_dir '{if(NR%50 == 0) { print >text_dir"/dev.txt"; } else {print;}}' >$text_dir/swbd.txt
  cat > $dir/config/hesitation_mapping.txt <<EOF
hmm hum
mmm um
mm um
mhm um-hum
EOF
  # map Fisher hesitation words to their swbd equivalents, then reverse word order.
  gunzip -c $fisher_text | awk 'NR==FNR{a[$1]=$2;next}{for (n=1;n<=NF;n++) if ($n in a) $n=a[$n];print $0}' \
    $dir/config/hesitation_mapping.txt - | awk '{for(i=NF;i>0;i--) printf("%s ", $i); print""}' > $text_dir/fisher.txt
fi

if [ $stage -le 1 ]; then
  cp data/lang/words.txt $dir/config/
  n=`cat $dir/config/words.txt | wc -l`
  echo "<brk> $n" >> $dir/config/words.txt

  # words that are not present in words.txt but are in the training or dev data, will be
  # mapped to <SPOKEN_NOISE> during training.
  echo "<unk>" >$dir/config/oov.txt

  cat > $dir/config/data_weights.txt <<EOF
swbd 3 1.0
fisher 1 1.0
EOF

  rnnlm/get_unigram_probs.py --vocab-file=$dir/config/words.txt \
                             --unk-word="<unk>" \
                             --data-weights-file=$dir/config/data_weights.txt \
                             $text_dir | awk 'NF==2' >$dir/config/unigram_probs.txt

  # choose features
  rnnlm/choose_features.py --unigram-probs=$dir/config/unigram_probs.txt \
                           --use-constant-feature=true \
                           --special-words='<s>,</s>,<brk>,<unk>,[noise],[laughter],[vocalized-noise]' \
                           $dir/config/words.txt > $dir/config/features.txt

  cat >$dir/config/xconfig <<EOF
input dim=$embedding_dim name=input
relu-renorm-layer name=tdnn1 dim=$embedding_dim input=Append(0, IfDefined(-1))
fast-lstmp-layer name=lstm1 cell-dim=$embedding_dim recurrent-projection-dim=$lstm_rpd non-recurrent-projection-dim=$lstm_nrpd
relu-renorm-layer name=tdnn2 dim=$embedding_dim input=Append(0, IfDefined(-3))
fast-lstmp-layer name=lstm2 cell-dim=$embedding_dim recurrent-projection-dim=$lstm_rpd non-recurrent-projection-dim=$lstm_nrpd
relu-renorm-layer name=tdnn3 dim=$embedding_dim input=Append(0, IfDefined(-3))
output-layer name=output include-log-softmax=false dim=$embedding_dim
EOF
  rnnlm/validate_config_dir.sh $text_dir $dir/config
fi

if [ $stage -le 2 ]; then
  rnnlm/prepare_rnnlm_dir.sh $text_dir $dir/config $dir
fi

if [ $stage -le 3 ]; then
  rnnlm/train_rnnlm.sh --num-jobs-initial 1 --num-jobs-final 3 \
                       --stage $train_stage --num-epochs 10 --cmd "$train_cmd" $dir
fi

LM=sw1_fsh_fg # using the 4-gram const arpa file as old lm
if [ $stage -le 4 ] && $run_lat_rescore; then
  echo "$0: Perform lattice-rescoring on $ac_model_dir"

  for decode_set in eval2000; do
    decode_dir=${ac_model_dir}/decode_${decode_set}_${LM}_looped
    # NOTE(review): this checks ${decode_dir}_${decode_dir_suffix_forward}, but the
    # rescoring call below reads from ${decode_dir}_${decode_dir_suffix_forward}_0.45
    # (with the weight suffix) — confirm the forward recipe's output dir naming.
    if [ ! -d ${decode_dir}_${decode_dir_suffix_forward} ]; then
      echo "$0: Must run the forward recipe first at local/rnnlm/run_tdnn_lstm.sh"
      exit 1
    fi

    # Lattice rescoring
    rnnlm/lmrescore_back.sh \
      --cmd "$decode_cmd --mem 4G" \
      --weight 0.45 --max-ngram-order $ngram_order \
      data/lang_$LM $dir \
      data/${decode_set}_hires ${decode_dir}_${decode_dir_suffix_forward}_0.45 \
      ${decode_dir}_${decode_dir_suffix_backward}_0.45
  done
fi

exit 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
#!/bin/bash

# Copyright 2017 Hainan Xu
# Apache 2.0

# This script rescores lattices with a KALDI RNNLM trained on reversed text.
# The input directory should already be rescored with a forward RNNLM, preferably
# with the pruned algorithm, since smaller lattices make rescoring much faster.
# An example of the forward pruned rescoring is at
# egs/swbd/s5c/local/rnnlm/run_tdnn_lstm.sh
# One example script for backward RNNLM rescoring is at
# egs/swbd/s5c/local/rnnlm/run_tdnn_lstm_back.sh

# Begin configuration section.
cmd=run.pl
skip_scoring=false
max_ngram_order=4 # Approximate the lattice-rescoring by limiting the max-ngram-order
                  # if it's set, it merges histories in the lattice if they share
                  # the same ngram history and this prevents the lattice from
                  # exploding exponentially. Details of the n-gram approximation
                  # method are described in section 2.3 of the paper
                  # http://www.danielpovey.com/files/2018_icassp_lattice_pruning.pdf

weight=0.5 # Interpolation weight for RNNLM.
normalize=false # If true, we add a normalization step to the output of the RNNLM
                # so that it adds up to *exactly* 1. Note that this is not necessary
                # as in our RNNLM setup, a properly trained network would automatically
                # have its normalization term close to 1. The details of this
                # could be found at http://www.danielpovey.com/files/2018_icassp_rnnlm.pdf

# End configuration section.

echo "$0 $@"  # Print the command line for logging

. ./utils/parse_options.sh

if [ $# != 5 ]; then
  echo "Does language model rescoring of lattices (remove old LM, add new LM)"
  echo "with Kaldi RNNLM trained on reversed text. See comments in file for details"
  echo ""
  echo "Usage: $0 [options] <old-lang-dir> <rnnlm-dir> \\"
  echo "       <data-dir> <input-decode-dir> <output-decode-dir>"
  echo " e.g.: $0 data/lang_tg exp/rnnlm_lstm/ data/test \\"
  echo "       exp/tri3/test_rnnlm_forward exp/tri3/test_rnnlm_bidirection"
  echo "options: [--cmd (run.pl|queue.pl [queue opts])]"
  exit 1;
fi

[ -f path.sh ] && . ./path.sh;

oldlang=$1
rnnlm_dir=$2
data=$3
indir=$4
outdir=$5

# Prefer the FST-format old LM; fall back to the const-arpa format.
oldlm=$oldlang/G.fst
if [ ! -f $oldlm ]; then
  echo "$0: file $oldlm not found; using $oldlang/G.carpa"
  oldlm=$oldlang/G.carpa
fi

[ ! -f $oldlm ] && echo "$0: Missing file $oldlm" && exit 1;
[ ! -f $rnnlm_dir/final.raw ] && echo "$0: Missing file $rnnlm_dir/final.raw" && exit 1;
[ ! -f $rnnlm_dir/feat_embedding.final.mat ] && [ ! -f $rnnlm_dir/word_embedding.final.mat ] && echo "$0: Missing word embedding file" && exit 1;

[ ! -f $oldlang/words.txt ] &&\
   echo "$0: Missing file $oldlang/words.txt" && exit 1;
! ls $indir/lat.*.gz >/dev/null &&\
   echo "$0: No lattices in input directory $indir" && exit 1;
awk -v n=$0 -v w=$weight 'BEGIN {if (w < 0 || w > 1) {
  print n": Interpolation weight should be in the range of [0, 1]"; exit 1;}}' \
  || exit 1;

normalize_opt=
if $normalize; then
  normalize_opt="--normalize-probs=true"
fi
oldlm_command="fstproject --project_output=true $oldlm |"
special_symbol_opts=$(cat $rnnlm_dir/special_symbol_opts.txt)

# Use the precomputed word embedding if present; otherwise generate it
# on the fly from the feature embedding.
word_embedding=
if [ -f $rnnlm_dir/word_embedding.final.mat ]; then
  word_embedding=$rnnlm_dir/word_embedding.final.mat
else
  word_embedding="'rnnlm-get-word-embedding $rnnlm_dir/word_feats.txt $rnnlm_dir/feat_embedding.final.mat -|'"
fi

mkdir -p $outdir/log
nj=`cat $indir/num_jobs` || exit 1;
cp $indir/num_jobs $outdir

# In order to rescore with a backward RNNLM, we first remove the original LM
# scores with lattice-lmrescore, then reverse the lattices, rescore with the
# backward RNNLM, and reverse them back.
oldlm_weight=$(perl -e "print -1.0 * $weight;")
if [ "$oldlm" == "$oldlang/G.fst" ]; then
  $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \
    lattice-lmrescore --lm-scale=$oldlm_weight \
    "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \
    lattice-reverse ark:- ark:- \| \
    lattice-lmrescore-kaldi-rnnlm --lm-scale=$weight $special_symbol_opts \
    --max-ngram-order=$max_ngram_order $normalize_opt \
    $word_embedding "$rnnlm_dir/final.raw" ark:- ark:- \| \
    lattice-reverse ark:- "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1;
else
  $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \
    lattice-lmrescore-const-arpa --lm-scale=$oldlm_weight \
    "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm" ark:- \| \
    lattice-reverse ark:- ark:- \| \
    lattice-lmrescore-kaldi-rnnlm --lm-scale=$weight $special_symbol_opts \
    --max-ngram-order=$max_ngram_order $normalize_opt \
    $word_embedding "$rnnlm_dir/final.raw" ark:- ark:- \| \
    lattice-reverse ark:- "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1;
fi

if ! $skip_scoring ; then
  err_msg="$0: Not scoring because local/score.sh does not exist or not executable."
  [ ! -x local/score.sh ] && echo $err_msg && exit 1;
  echo local/score.sh --cmd "$cmd" $data $oldlang $outdir
  local/score.sh --cmd "$cmd" $data $oldlang $outdir
else
  echo "$0: Not scoring because --skip-scoring was specified."
fi

exit 0;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// latbin/lattice-reverse.cc

// Copyright 2018 Hainan Xu

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.


#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"

// Reads lattices from an rspecifier, reverses each one with fst::Reverse(),
// and writes the result to a wspecifier.  This is used for rescoring with an
// RNNLM trained on reversed (right-to-left) text: the lattice is reversed,
// rescored, and then reversed back by a second invocation of this program.
int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    typedef kaldi::int32 int32;

    const char *usage =
        "Reverse a lattice in order to rescore the lattice with an RNNLM\n"
        "trained on reversed text. An example for its application is at\n"
        "swbd/local/rnnlm/run_tdnn_lstm_back_1e.sh\n"
        "Usage: lattice-reverse lattice-rspecifier lattice-wspecifier\n"
        " e.g.: lattice-reverse ark:forward.lats ark:backward.lats\n";

    ParseOptions po(usage);

    po.Read(argc, argv);

    if (po.NumArgs() != 2) {
      po.PrintUsage();
      exit(1);
    }

    std::string lats_rspecifier = po.GetArg(1),
        lats_wspecifier = po.GetArg(2);

    int32 n_done = 0;  // number of lattices successfully reversed

    SequentialLatticeReader lattice_reader(lats_rspecifier);
    LatticeWriter lattice_writer(lats_wspecifier);

    for (; !lattice_reader.Done(); lattice_reader.Next(), n_done++) {
      const Lattice &lat = lattice_reader.Value();
      Lattice olat;
      fst::Reverse(lat, &olat);
      lattice_writer.Write(lattice_reader.Key(), olat);
    }

    KALDI_LOG << "Done reversing " << n_done << " lattices.";

    // Exit status 0 only if at least one lattice was processed.
    return (n_done != 0 ? 0 : 1);
  } catch(const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}