[egs] Fix LM/lexicon issues in IAM; Add unk decoding; Update results. (
hhadian authored and danpovey committed Mar 27, 2018
1 parent e5b6696 commit d7e8890
Showing 9 changed files with 428 additions and 43 deletions.
15 changes: 15 additions & 0 deletions egs/iam/v1/local/chain/compare_wer.sh
@@ -11,6 +11,7 @@ if [ $# == 0 ]; then
echo "e.g.: $0 exp/chain/cnn{1a,1b}"
exit 1
fi
. ./path.sh

echo "# $0 $*"
used_epochs=false
@@ -26,6 +27,13 @@ for x in $*; do
done
echo

echo -n "# CER "
for x in $*; do
  cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}')
  printf "% 10s" $cer
done
echo

if $used_epochs; then
exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems.
fi
@@ -57,3 +65,10 @@ for x in $*; do
printf "% 10s" $prob
done
echo

echo -n "# Parameters "
for x in $*; do
  params=$(nnet3-info $x/final.mdl 2>/dev/null | grep num-parameters | cut -d' ' -f2 | awk '{printf "%0.2fM\n",$1/1000000}')
  printf "% 10s" $params
done
echo
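
The new CER row reads decode_test/scoring_kaldi/best_cer, which holds a single line in the same format as the old scoring lines quoted in run_cnn_1a.sh below, so awk '{print $2}' picks out the percentage. A minimal illustration (10.17 is the cnn_1a CER from the table below; the bracketed counts and the LMWT suffix are elided/illustrative):

  # best_cer contains a line such as:
  #   %WER 10.17 [ ... ] exp/chain/cnn_1a/decode_test/cer_11_0.0
  # so the extraction in the CER loop is effectively:
  cat exp/chain/cnn_1a/decode_test/scoring_kaldi/best_cer | awk '{print $2}'   # -> 10.17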
14 changes: 10 additions & 4 deletions egs/iam/v1/local/chain/run_cnn_1a.sh
@@ -7,9 +7,15 @@
# steps/info/chain_dir_info.pl exp/chain/cnn_1a/
# exp/chain/cnn_1a/: num-iters=21 nj=2..4 num-params=4.4M dim=40->364 combine=-0.021->-0.015 xent:train/valid[13,20,final]=(-1.05,-0.701,-0.591/-1.30,-1.08,-1.00) logprob:train/valid[13,20,final]=(-0.061,-0.034,-0.030/-0.107,-0.101,-0.098)

# cat exp/chain/cnn_1a/decode_test/scoring_kaldi/best_*
# %WER 5.94 [ 3913 / 65921, 645 ins, 1466 del, 1802 sub ] exp/chain/cnn_1a/decode_test//cer_11_0.0
# %WER 9.13 [ 1692 / 18542, 162 ins, 487 del, 1043 sub ] exp/chain/cnn_1a/decode_test/wer_11_0.0
# local/chain/compare_wer.sh exp/chain/cnn_1a/
# System cnn_1a
# WER 18.58
# CER 10.17
# Final train prob -0.0122
# Final valid prob -0.0999
# Final train prob (xent) -0.5652
# Final valid prob (xent) -0.9758
# Parameters 4.36M

set -e -o pipefail

@@ -40,7 +46,7 @@ tdnn_dim=450
# training options
srand=0
remove_egs=false
lang_test=lang_test
lang_test=lang_unk
# End configuration section.
echo "$0 $@" # Print the command line for logging

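
Both chain scripts now set lang_test=lang_unk instead of lang_test, which is what the "unk decoding" in the commit title refers to. The preparation of the lang_unk directory itself is in files not shown here; below is a minimal sketch of how such a directory is commonly built in Kaldi, assuming the stock utils/lang/make_unk_lm.sh and the --unk-fst option of utils/prepare_lang.sh (paths are placeholders, not necessarily the recipe's actual setup):

  # Sketch only -- not necessarily this recipe's lang_unk preparation.
  utils/lang/make_unk_lm.sh data/local/dict exp/unk_lang_model
  utils/prepare_lang.sh --unk-fst exp/unk_lang_model/unk_fst.txt \
    data/local/dict "<unk>" data/local/lang_tmp data/lang_unk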
22 changes: 11 additions & 11 deletions egs/iam/v1/local/chain/run_cnn_chainali_1b.sh
@@ -1,20 +1,20 @@
#!/bin/bash

# chainali_1b is as chainali_1a except it has 3 more cnn layers and 1 less tdnn layer.
# ./local/chain/compare_wer.sh exp/chain/cnn_chainali_1a/ exp/chain/cnn_chainali_1b/
# System cnn_chainali_1a cnn_chainali_1b
# WER 6.69 6.25
# Final train prob -0.0132 -0.0041
# Final valid prob -0.0509 -0.0337
# Final train prob (xent) -0.6393 -0.6287
# Final valid prob (xent) -1.0116 -0.9064

# local/chain/compare_wer.sh exp/chain/cnn_1a/ exp/chain/cnn_chainali_1b/
# System cnn_1a cnn_chainali_1b
# WER 18.58 14.67
# CER 10.17 7.31
# Final train prob -0.0122 0.0042
# Final valid prob -0.0999 -0.0256
# Final train prob (xent) -0.5652 -0.6282
# Final valid prob (xent) -0.9758 -0.9096
# Parameters 4.36M 3.96M

# steps/info/chain_dir_info.pl exp/chain/chainali_cnn_1b/
# exp/chain/chainali_cnn_1b/: num-iters=21 nj=2..4 num-params=4.0M dim=40->364 combine=-0.009->-0.005 xent:train/valid[13,20,final]=(-1.47,-0.728,-0.623/-1.69,-1.02,-0.940) logprob:train/valid[13,20,final]=(-0.068,-0.030,-0.011/-0.086,-0.056,-0.038)

# cat exp/chain/cnn_chainali_1b/decode_test/scoring_kaldi/best_*
# %WER 3.94 [ 2600 / 65921, 415 ins, 1285 del, 900 sub ] exp/chain/cnn_chainali_1b/decode_test/cer_10_0.0
# %WER 6.25 [ 1158 / 18542, 103 ins, 469 del, 586 sub ] exp/chain/cnn_chainali_1b/decode_test/wer_12_0.0

set -e -o pipefail

@@ -46,7 +46,7 @@ tdnn_dim=450
# training options
srand=0
remove_egs=false
lang_test=lang_test
lang_test=lang_unk
# End configuration section.
echo "$0 $@" # Print the command line for logging

28 changes: 13 additions & 15 deletions egs/iam/v1/local/prepare_dict.sh
@@ -15,29 +15,27 @@ cat data/train/text | \
perl -ne '@A = split; shift @A; for(@A) {print join("\n", split(//)), "\n";}' | \
sort -u > $dir/nonsilence_phones.txt

# Now list all the unique words (that use only the above letters)
# in data/train/text and LOB+Brown corpora with their comprising
# letters as their transcription. (Letter # is replaced with <HASH>)
# Now use pocolm's wordlist, which contains the N most frequent words from
# data/train/text and the LOB+Brown corpora (dev and test excluded), with their
# constituent letters as their transcription. Only include words that use the
# above letters. (The letter # is replaced with <HASH>.)

export letters=$(cat $dir/nonsilence_phones.txt | tr -d "\n")

cut -d' ' -f2- data/train/text | \
cat data/local/lobcorpus/0167/download/LOB_COCOA/lob.txt \
data/local/browncorpus/brown.txt - | \
cat data/local/local_lm/data/wordlist | \
perl -e '$letters=$ENV{letters};
while(<>){ @A = split;
foreach(@A) {
if(! $seen{$_} && $_ =~ m/^[$letters]+$/){
$seen{$_} = 1;
$trans = join(" ", split(//));
while(<>){
chop;
$w = $_;
if($w =~ m/^[$letters]+$/){
$trans = join(" ", split(//, $w));
$trans =~ s/#/<HASH>/g;
print "$_ $trans\n";
print "$w $trans\n";
}
}
}' | sort > $dir/lexicon.txt
}' | sort -u > $dir/lexicon.txt


sed -i '' "s/#/<HASH>/" $dir/nonsilence_phones.txt
sed -i "s/#/<HASH>/" $dir/nonsilence_phones.txt

echo '<sil> SIL' >> $dir/lexicon.txt
echo '<unk> SIL' >> $dir/lexicon.txt
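
To make the new lexicon construction concrete, here is a small standalone illustration of what the rewritten perl filter emits for a few sample wordlist entries, assuming the lowercase letters and '#' all appear in nonsilence_phones.txt (the words and the simplified one-liner are illustrative; the letter-membership check is omitted):

  printf 'the\nc#\nhello\n' | \
    perl -e 'while(<>){ chop; $w = $_;
                        $t = join(" ", split(//, $w));
                        $t =~ s/#/<HASH>/g;
                        print "$w $t\n"; }'
  # -> the t h e
  # -> c# c <HASH>
  # -> hello h e l l o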
117 changes: 117 additions & 0 deletions egs/iam/v1/local/remove_test_utterances_from_lob.py
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
# Copyright 2018 Ashish Arora

import argparse
import os
import numpy as np
import sys
import re

parser = argparse.ArgumentParser(description="""Removes dev/test set lines
                                 from the LOB corpus. Reads the
                                 corpus from stdin, and writes it to stdout.""")
parser.add_argument('dev_text', type=str,
                    help='dev transcription location.')
parser.add_argument('test_text', type=str,
                    help='test transcription location.')
args = parser.parse_args()

def remove_punctuations(transcript):
    char_list = []
    for char in transcript:
        if char.isdigit() or char == '+' or char == '~' or char == '?':
            continue
        if char == '#' or char == '=' or char == '-' or char == '!':
            continue
        if char == ',' or char == '.' or char == ')' or char == '\'':
            continue
        if char == '(' or char == ':' or char == ';' or char == '"':
            continue
        char_list.append(char)
    return char_list


def remove_special_words(words):
    word_list = []
    for word in words:
        if word == '<SIC>' or word == '#':
            continue
        word_list.append(word)
    return word_list


# Read the dev/eval transcripts into utterance_dict:
# remove special words and punctuation, drop the spaces between words,
# and lowercase the characters.
def read_utterances(text_file_path):
    with open(text_file_path, 'rt') as in_file:
        for line in in_file:
            words = line.strip().split()
            words_wo_sw = remove_special_words(words)
            transcript = ''.join(words_wo_sw[1:])
            transcript = transcript.lower()
            trans_wo_punct = remove_punctuations(transcript)
            transcript = ''.join(trans_wo_punct)
            utterance_dict[words_wo_sw[0]] = transcript


### main ###

# read dev/test utterances and add them to utterance_dict
utterance_dict = dict()
read_utterances(args.dev_text)
read_utterances(args.test_text)

# read the corpus from stdin and add each line to the lists below
corpus_text_lowercase_wo_sc = list()
corpus_text_wo_sc = list()
original_corpus_text = list()
for line in sys.stdin:
    original_corpus_text.append(line)
    words = line.strip().split()
    words_wo_sw = remove_special_words(words)

    transcript = ''.join(words_wo_sw)
    transcript = transcript.lower()
    trans_wo_punct = remove_punctuations(transcript)
    transcript = ''.join(trans_wo_punct)
    corpus_text_lowercase_wo_sc.append(transcript)

    transcript = ''.join(words_wo_sw)
    trans_wo_punct = remove_punctuations(transcript)
    transcript = ''.join(trans_wo_punct)
    corpus_text_wo_sc.append(transcript)

# Search for each dev/test utterance in the corpus, looking at 3 consecutive
# corpus lines at a time; matching corpus lines are marked for removal, and
# utterances that cannot be found are collected in remaining_utterances.
row_to_keep = [True for i in range(len(original_corpus_text))]
remaining_utterances = dict()
for line_id, line_to_find in utterance_dict.items():
    found_line = False
    for i in range(1, (len(corpus_text_lowercase_wo_sc) - 2)):
        # Combine 3 consecutive lines of the corpus into a single line
        prev_words = corpus_text_lowercase_wo_sc[i - 1].strip()
        curr_words = corpus_text_lowercase_wo_sc[i].strip()
        next_words = corpus_text_lowercase_wo_sc[i + 1].strip()
        new_line = prev_words + curr_words + next_words
        transcript = ''.join(new_line)
        if line_to_find in transcript:
            found_line = True
            row_to_keep[i-1] = False
            row_to_keep[i] = False
            row_to_keep[i+1] = False
    if not found_line:
        remaining_utterances[line_id] = line_to_find


for i in range(len(original_corpus_text)):
    transcript = original_corpus_text[i].strip()
    if row_to_keep[i]:
        print(transcript)

print('Sentences not removed from LOB: {}'.format(remaining_utterances), file=sys.stderr)
print('Total test+dev sentences: {}'.format(len(utterance_dict)), file=sys.stderr)
print('Number of sentences not removed from LOB: {}'.format(len(remaining_utterances)), file=sys.stderr)
print('LOB lines: Before: {} After: {}'.format(len(original_corpus_text),
                                               row_to_keep.count(True)), file=sys.stderr)
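
The script is written to be used as a filter: the LOB corpus comes in on stdin, the filtered corpus goes to stdout, and the dev/test transcripts are passed as arguments. A hypothetical invocation, with the LOB path taken from prepare_dict.sh above and the dev/test text paths and the output path assumed for illustration:

  cat data/local/lobcorpus/0167/download/LOB_COCOA/lob.txt | \
    local/remove_test_utterances_from_lob.py data/val/text data/test/text \
    > data/local/lob_train_only.txt   # output path is an assumption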