-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[egs] LT-LM recipe for librispeech (#4590)
Co-authored-by: Anton Mitrofanov <mitrofanov-aa@speechpro.com>
- Loading branch information
Showing
50 changed files
with
4,926 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# LT-LM: a novel non-autoregressive language model for single-shot lattice rescoring | ||
[Paper](https://arxiv.org/pdf/2104.02526.pdf) | ||
|
||
## Setup: | ||
`cd fairseq_ltlm && bash setup.sh` | ||
## run: | ||
* put your slurm.conf into conf/ | ||
* modify fairseq\_ltlm/recipes/config.sh if needed | ||
* `bash fairseq\_ltlm/recipes/run.sh` | ||
## evaluate: | ||
For evaluation, you can | ||
run fairseq\_ltlm/recipes/run\_5\_eval.sh (see run.sh) or use fairseq\_ltlm/ltlm/eval.py directly. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Build rules for the latgen-faster-mapped-fake-am helper binary.
# `kaldi` is the relative path from this directory to the Kaldi root.
kaldi=../../../../..

# Kept ahead of the includes so that `all` is the first target in this file.
all:

include $(kaldi)/src/kaldi.mk

# Silence sign-compare warnings and make the Kaldi headers visible.
EXTRA_CXXFLAGS += -Wno-sign-compare -I$(kaldi)/src

BINFILES = latgen-faster-mapped-fake-am
OBJFILES =
TESTFILES =

# Static libraries to link against; the order is kept as in the original list.
ADDLIBS = $(kaldi)/src/decoder/kaldi-decoder.a \
          $(kaldi)/src/lat/kaldi-lat.a \
          $(kaldi)/src/lm/kaldi-lm.a \
          $(kaldi)/src/fstext/kaldi-fstext.a \
          $(kaldi)/src/hmm/kaldi-hmm.a \
          $(kaldi)/src/transform/kaldi-transform.a \
          $(kaldi)/src/gmm/kaldi-gmm.a \
          $(kaldi)/src/tree/kaldi-tree.a \
          $(kaldi)/src/util/kaldi-util.a \
          $(kaldi)/src/matrix/kaldi-matrix.a \
          $(kaldi)/src/base/kaldi-base.a

include $(kaldi)/src/makefiles/default_rules.mk
|
213 changes: 213 additions & 0 deletions
213
egs/librispeech/s5/fairseq_ltlm/kaldi_utils/latgen-faster-mapped-fake-am.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
|
||
// Copyright (c) 2021, Speech Technology Center Ltd. All rights reserved. | ||
// Anton Mitrofanov | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
// This is latgen-faster-mapped adapted for fake lattice generation | ||
|
||
#include <chrono> | ||
|
||
#include "base/kaldi-common.h" | ||
#include "base/timer.h" | ||
#include "decoder/decodable-matrix.h" | ||
#include "decoder/decoder-wrappers.h" | ||
#include "fstext/fstext-lib.h" | ||
#include "hmm/transition-model.h" | ||
#include "tree/context-dep.h" | ||
#include "util/common-utils.h" | ||
|
||
using namespace kaldi; | ||
typedef kaldi::int32 int32; | ||
using fst::Fst; | ||
using fst::StdArc; | ||
using fst::SymbolTable; | ||
|
||
int main(int argc, char *argv[]) { | ||
try { | ||
const char *usage = | ||
"Generate lattices, reading emulating am as matrices\n" | ||
" (model is needed only for the integer mappings in its " | ||
"transition-model)\n" | ||
"Usage: latgen-faster-mapped-fake-am [options] trans-model-in fst-in " | ||
"fam-rspecifier ali_rspecifier" | ||
" lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; | ||
ParseOptions po(usage); | ||
Timer timer; | ||
bool allow_partial = false; | ||
BaseFloat acoustic_scale = 0.1; | ||
LatticeFasterDecoderConfig config; | ||
|
||
std::string word_syms_filename; | ||
config.Register(&po); | ||
po.Register("acoustic-scale", &acoustic_scale, | ||
"Scaling factor for acoustic likelihoods"); | ||
|
||
po.Register("word-symbol-table", &word_syms_filename, | ||
"Symbol table for words [for debug output]"); | ||
po.Register("allow-partial", &allow_partial, | ||
"If true, produce output even if end state was not reached."); | ||
|
||
po.Read(argc, argv); | ||
|
||
if (po.NumArgs() < 5 || po.NumArgs() > 7) { | ||
po.PrintUsage(); | ||
exit(1); | ||
} | ||
|
||
std::string model_in_filename = po.GetArg(1), fst_in_str = po.GetArg(2), | ||
fam_rspecifier = po.GetArg(3), ali_rspecifier = po.GetArg(4), | ||
lattice_wspecifier = po.GetArg(5), | ||
words_wspecifier = po.GetOptArg(6), | ||
alignment_wspecifier = po.GetOptArg(7); | ||
|
||
TransitionModel trans_model; | ||
ReadKaldiObject(model_in_filename, &trans_model); | ||
|
||
bool determinize = config.determinize_lattice; | ||
CompactLatticeWriter compact_lattice_writer; | ||
LatticeWriter lattice_writer; | ||
if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier) | ||
: lattice_writer.Open(lattice_wspecifier))) | ||
KALDI_ERR << "Could not open table for writing lattices: " | ||
<< lattice_wspecifier; | ||
|
||
Int32VectorWriter words_writer(words_wspecifier); | ||
|
||
Int32VectorWriter alignment_writer(alignment_wspecifier); | ||
|
||
fst::SymbolTable *word_syms = NULL; | ||
if (word_syms_filename != "") | ||
if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) | ||
KALDI_ERR << "Could not read symbol table from file " | ||
<< word_syms_filename; | ||
|
||
double tot_like = 0.0; | ||
kaldi::int64 frame_count = 0; | ||
int num_success = 0, num_fail = 0; | ||
|
||
// Reading Fake acoustic model form ark file | ||
KALDI_LOG << "Loading Fake Acoustic Model"; | ||
SequentialBaseFloatMatrixReader fam_model_read(fam_rspecifier); | ||
std::string fam_model_key = fam_model_read.Key(); | ||
Matrix<BaseFloat> fam_model(fam_model_read.Value()); | ||
KALDI_LOG << "Apply log."; | ||
fam_model.ApplyLog(); | ||
|
||
if (fam_model_key != "fam_model") { | ||
KALDI_ERR << fam_rspecifier << " - Wrong fam_model."; | ||
po.PrintUsage(); | ||
exit(1); | ||
} | ||
KALDI_LOG << "Fake Acoustic is loaded. Shape is (" << fam_model.NumRows() | ||
<< ", " << fam_model.NumCols() << ")"; | ||
|
||
SequentialInt32VectorReader ali_reader(ali_rspecifier); | ||
if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { | ||
// SequentialBaseFloatMatrixReader loglike_reader(feature_rspecifier); | ||
// Input FST is just one FST, not a table of FSTs. | ||
Fst<StdArc> *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); | ||
timer.Reset(); | ||
|
||
{ | ||
LatticeFasterDecoder decoder(*decode_fst, config); | ||
for (; !ali_reader.Done(); ali_reader.Next()) { | ||
std::string utt = ali_reader.Key(); | ||
std::vector<int32> ali(ali_reader.Value()); | ||
KALDI_LOG << "Process " << utt << ". " << ali.size() << " frames"; | ||
ali_reader.FreeCurrent(); | ||
if (ali.size() == 0) { | ||
KALDI_WARN << "Zero-length utterance: " << utt; | ||
num_fail++; | ||
continue; | ||
} | ||
// Inference fake AM | ||
kaldi::Matrix<BaseFloat> loglikes(ali.size(), fam_model.NumRows()); | ||
loglikes.SetZero(); | ||
for (int i = 0; i < ali.size(); i++) { | ||
int32 pdf_id = ali[i]; | ||
loglikes.CopyRowFromVec(fam_model.Row(pdf_id), i); | ||
SubVector<BaseFloat> row(loglikes, i); | ||
} | ||
// end | ||
DecodableMatrixScaledMapped decodable(trans_model, loglikes, | ||
acoustic_scale); | ||
|
||
double like; | ||
if (DecodeUtteranceLatticeFaster( | ||
decoder, decodable, trans_model, word_syms, utt, | ||
acoustic_scale, determinize, allow_partial, &alignment_writer, | ||
&words_writer, &compact_lattice_writer, &lattice_writer, | ||
&like)) { | ||
tot_like += like; | ||
frame_count += loglikes.NumRows(); | ||
num_success++; | ||
} else | ||
num_fail++; | ||
} | ||
} | ||
delete decode_fst; // delete this only after decoder goes out of scope. | ||
} else { // We have different FSTs for different utterances. | ||
KALDI_LOG << "FSTs not implemented yet."; | ||
exit(1); | ||
// SequentialTableReader<fst::VectorFstHolder> | ||
// fst_reader(fst_in_str); RandomAccessBaseFloatMatrixReader | ||
// loglike_reader(feature_rspecifier); for (; !fst_reader.Done(); | ||
// fst_reader.Next()) { | ||
// std::string utt = fst_reader.Key(); | ||
// if (!loglike_reader.HasKey(utt)) { | ||
// KALDI_WARN << "Not decoding utterance " << utt | ||
// << " because no loglikes available."; | ||
// num_fail++; | ||
// continue; | ||
// } | ||
// const Matrix<BaseFloat> &loglikes = loglike_reader.Value(utt); | ||
// if (loglikes.NumRows() == 0) { | ||
// KALDI_WARN << "Zero-length utterance: " << utt; | ||
// num_fail++; | ||
// continue; | ||
// } | ||
// LatticeFasterDecoder decoder(fst_reader.Value(), config); | ||
// DecodableMatrixScaledMapped decodable(trans_model, loglikes, | ||
// acoustic_scale); double like; if (DecodeUtteranceLatticeFaster( | ||
// decoder, decodable, trans_model, word_syms, utt, | ||
// acoustic_scale, determinize, allow_partial, | ||
// &alignment_writer, &words_writer, | ||
// &compact_lattice_writer, &lattice_writer, &like)) { | ||
// tot_like += like; | ||
// frame_count += loglikes.NumRows(); | ||
// num_success++; | ||
// } else num_fail++; | ||
// } | ||
} | ||
|
||
double elapsed = timer.Elapsed(); | ||
KALDI_LOG << "Time taken " << elapsed | ||
<< "s: real-time factor assuming 100 frames/sec is " | ||
<< (elapsed * 100.0 / frame_count); | ||
KALDI_LOG << "Done " << num_success << " utterances, failed for " | ||
<< num_fail; | ||
KALDI_LOG << "Overall log-likelihood per frame is " | ||
<< (tot_like / frame_count) << " over " << frame_count | ||
<< " frames."; | ||
|
||
delete word_syms; | ||
if (num_success != 0) | ||
return 0; | ||
else | ||
return 1; | ||
} catch (const std::exception &e) { | ||
std::cerr << e.what(); | ||
return -1; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright 2021 STC-Innovation LTD (Author: Anton Mitrofanov) | ||
import argparse | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class WordTokenizer:
    """Word <-> integer id mapping loaded from a ``words.txt`` file.

    Each non-empty line of the file must hold exactly two whitespace-separated
    fields: the word and its integer id.  The vocabulary must map ``<eps>`` to
    0 and contain ``<s>`` and ``</s>``.

    Attributes:
        word2id: dict mapping word -> id.
        id2word: list mapping id -> word ('' for ids absent from the file).
        unk: the unknown-word symbol actually present in the vocabulary.
        real_words_ids: ids of entries that are not special symbols.
        disambig_word_ids: ids of special symbols other than ``<s>``/``</s>``.
    """

    # A vocabulary entry containing any of these characters, or starting or
    # ending with '-', is treated as a special (non-lexical) symbol.
    _SPECIAL_CHARS = '<>#![]'

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Register the tokenizer's command-line options on ``parser``."""
        parser.add_argument('--tokenizer_fn', type=str, required=True,
                            help='Tokenizer fname( words.txt)')
        # parser.add_argument('--unk', default='<UNK>', help="Unk word")  # fairseq bug

    @staticmethod
    def build_from_args(args):
        """Build a tokenizer from parsed arguments (reads ``args.tokenizer_fn``).

        The unk word is not taken from ``args`` because the corresponding
        option is disabled in ``add_args``; the constructor default is used.
        """
        return WordTokenizer(fname=args.tokenizer_fn)

    @staticmethod
    def _is_special(word):
        """Return True if ``word`` looks like a special (non-lexical) symbol."""
        return (any(ch in word for ch in WordTokenizer._SPECIAL_CHARS)
                or word.startswith('-')
                or word.endswith('-'))

    def __init__(self, fname, unk_word='<UNK>'):
        """Load the vocabulary from ``fname``.

        Args:
            fname: path to the words.txt file (read as UTF-8).
            unk_word: unknown-word symbol; if it is missing from the
                vocabulary, its lower-cased form is tried instead.

        Raises:
            ValueError: if the file has no entries, a line does not have
                exactly two fields, or the unk word is not in the vocabulary.
            AssertionError: if ``<eps>`` is not id 0 or ``<s>``/``</s>`` are
                missing.
        """
        logger.info(f'Loading WordTokenizer {fname}')
        self.word2id = {}
        with open(fname, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                word, idx = fields
                self.word2id[word] = int(idx)
        if not self.word2id:
            raise ValueError(f"no vocabulary entries found in {fname}")

        self.unk = unk_word
        if self.unk not in self.word2id:
            if self.unk.lower() in self.word2id:
                self.unk = self.unk.lower()
            else:
                # A bare string cannot be raised in Python 3; raise a real
                # exception carrying the same message.
                raise ValueError(f"unk word {unk_word} not in {fname}")

        self.id2word = [''] * (max(self.word2id.values()) + 1)
        for w, i in self.word2id.items():
            self.id2word[i] = w

        assert self.word2id.get('<eps>') == 0 and \
            '<s>' in self.word2id and \
            '</s>' in self.word2id, \
            f"{fname} must map <eps> to 0 and contain <s> and </s>"

        self.real_words_ids = [i for w, i in self.word2id.items()
                               if not self._is_special(w)]

        logger.info(f'WordTokenizer {fname} loaded. Vocab size {len(self)}.')
        self.disambig_word_ids = [i for w, i in self.word2id.items()
                                  if w not in ('<s>', '</s>') and self._is_special(w)]
        logger.info(f"WordTokenizer Disambig ids: {self.disambig_word_ids}")
        logger.info(f"WordTokenizer Disambig words: {[ self.id2word[i] for i in self.disambig_word_ids]}")

    def encode(self, text, bos=False, eos=False):
        """Convert an iterable of whitespace-tokenised lines into id lists.

        Out-of-vocabulary words map to the unk id.  ``bos``/``eos`` control
        whether the ``<s>``/``</s>`` ids are added around every line.
        """
        unk_id = self.word2id[self.unk]
        prefix = [self.get_bos_id()] if bos else []
        suffix = [self.get_eos_id()] if eos else []
        return [prefix + [self.word2id.get(w, unk_id) for w in line.split()] + suffix
                for line in text]

    def decode(self, text_ids):
        """Convert an iterable of id sequences back into lists of words."""
        return [[self.id2word[i] for i in line_ids] for line_ids in text_ids]

    def __len__(self):
        """Vocabulary size, i.e. the largest id plus one."""
        return len(self.id2word)

    def get_real_words_ids(self):
        return self.real_words_ids

    def get_disambig_words_ids(self):
        return self.disambig_word_ids

    def get_bos_id(self):
        return self.word2id["<s>"]

    def get_eos_id(self):
        return self.word2id["</s>"]

    def get_unk_id(self):
        return self.word2id[self.unk]

    def pad(self):
        """Id used for padding: the id of ``<eps>``."""
        return self.word2id['<eps>']

    def print_lat(self, lat, print_word_id=False, p=None):
        """Print one line per arc of ``lat``.

        Each arc is indexed as ``arc[0]`` (word id), ``arc[1]`` and ``arc[2]``
        (presumably the start and end states; confirm against the caller).
        The line is "<arc[1]> <arc[2]> <word or word id>", followed by
        ``p[i]`` when ``p`` is given.
        """
        for i, arc in enumerate(lat):
            out_str = f"{arc[1]} {arc[2]} {arc[0] if print_word_id else self.id2word[arc[0]]}"
            if p is not None:
                out_str += f" {p[i]}"
            print(out_str)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright 2021 STC-Innovation LTD (Author: Anton Mitrofanov) | ||
from .criterions.bce_loss import BCECriterion | ||
from .datasets import LatsOracleAlignDataSet | ||
from .models import LTLM | ||
from .tasks.rescoring_task import RescoringTask |
Oops, something went wrong.