From 0d839b0bf66ba09b930328b2fb313343c4d8b18e Mon Sep 17 00:00:00 2001
From: Hainan Xu
Date: Mon, 18 Sep 2017 16:32:33 -0400
Subject: [PATCH 01/23] draft

---
 src/latbin/lattice-lmrescore-nnet3-rnnlm.cc    | 147 ++++++++++++++
 src/rnnlm/Makefile                             |   2 +-
 .../kaldi-rnnlm-decodable-simple-looped.cc     | 182 +++++++++++++++++
 .../kaldi-rnnlm-decodable-simple-looped.h      | 187 ++++++++++++++++++
 4 files changed, 517 insertions(+), 1 deletion(-)
 create mode 100644 src/latbin/lattice-lmrescore-nnet3-rnnlm.cc
 create mode 100644 src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc
 create mode 100644 src/rnnlm/kaldi-rnnlm-decodable-simple-looped.h

diff --git a/src/latbin/lattice-lmrescore-nnet3-rnnlm.cc b/src/latbin/lattice-lmrescore-nnet3-rnnlm.cc
new file mode 100644
index 00000000000..26754d26629
--- /dev/null
+++ b/src/latbin/lattice-lmrescore-nnet3-rnnlm.cc
@@ -0,0 +1,147 @@
+// latbin/lattice-lmrescore-nnet3-rnnlm.cc
+
+// Copyright 2017  Johns Hopkins University (author: Daniel Povey)
+//                 Yiming Wang
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "base/kaldi-common.h"
+#include "fstext/fstext-lib.h"
+#include "lat/kaldi-lattice.h"
+#include "lat/lattice-functions.h"
+#include "rnnlm/kaldi-rnnlm-rescoring.h"
+#include "util/common-utils.h"
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    typedef kaldi::int32 int32;
+    typedef kaldi::int64 int64;
+
+    const char *usage =
+        "Rescores a lattice with an RNNLM.  The LM is wrapped in the\n"
+        "DeterministicOnDemandFst interface and the rescoring is done by\n"
+        "composing with the wrapped LM using a special type of composition\n"
+        "algorithm.  Determinization is applied to the composed lattice.\n"
+        "\n"
+        "Usage: lattice-lmrescore-nnet3-rnnlm [options] \\\n"
+        "             <rnn-wordlist> <word-symbol-table> \\\n"
+        "             <lattice-rspecifier> <rnnlm-rxfilename> <lattice-wspecifier>\n"
+        " e.g.: lattice-lmrescore-nnet3-rnnlm --lm-scale=-1.0 words.txt \\\n"
+        "              ark:in.lats rnnlm ark:out.lats\n";
+
+    ParseOptions po(usage);
+    int32 max_ngram_order = 3;
+    BaseFloat lm_scale = 1.0;
+
+    po.Register("lm-scale", &lm_scale, "Scaling factor for language model "
+                "costs; frequently 1.0 or -1.0");
+    po.Register("max-ngram-order", &max_ngram_order, "If positive, limit the "
+                "rnnlm context to the given number; -1 means we are not going "
+                "to limit it.");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 4 && po.NumArgs() != 5) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string lats_rspecifier, rnn_wordlist,
+        word_symbols_rxfilename, rnnlm_rxfilename, lats_wspecifier;
+    KALDI_ASSERT(po.NumArgs() == 5);
+
+    rnn_wordlist = po.GetArg(1);
+    word_symbols_rxfilename = po.GetArg(2);
+    lats_rspecifier = po.GetArg(3);
+    rnnlm_rxfilename = po.GetArg(4);
+    lats_wspecifier = po.GetArg(5);
+
+    // Reads the language model.
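+    // The RNNLM is a plain nnet3 Nnet read directly from disk; it must pass
+    // the IsSimpleNnet() check below so that the nnet3 looped-computation
+    // code can be reused for it.  The word embedding matrix is read
+    // separately, with one row per word.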
+    kaldi::nnet3::Nnet rnnlm;
+    ReadKaldiObject(rnnlm_rxfilename, &rnnlm);
+
+    if (!IsSimpleNnet(rnnlm))
+      KALDI_ERR << "Input RNNLM in " << rnnlm_rxfilename
+                << " is not the type of neural net we were looking for; "
+                   "failed IsSimpleNnet().";
+
+    CuMatrix<BaseFloat> word_embedding_mat;
+    ReadKaldiObject(word_embedding_rxfilename, &word_embedding_mat);
+
+    const nnet3::DecodableRnnlmSimpleLoopedComputationOptions opts;
+    const nnet3::DecodableRnnlmSimpleLoopedInfo info(opts, rnnlm,
+                                                     word_embedding_mat);
+
+    // Reads and writes as compact lattice.
+    SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier);
+    CompactLatticeWriter compact_lattice_writer(lats_wspecifier);
+
+    int32 n_done = 0, n_fail = 0;
+    for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) {
+      std::string key = compact_lattice_reader.Key();
+      CompactLattice clat = compact_lattice_reader.Value();
+      compact_lattice_reader.FreeCurrent();
+
+      if (lm_scale != 0.0) {
+        // Before composing with the LM FST, we scale the lattice weights
+        // by the inverse of "lm_scale".  We'll later scale by "lm_scale".
+        // We do it this way so we can determinize, and it will give the
+        // right effect (taking the "best path" through the LM) regardless
+        // of the sign of lm_scale.
+        fst::ScaleLattice(fst::GraphLatticeScale(1.0 / lm_scale), &clat);
+        ArcSort(&clat, fst::OLabelCompare<CompactLatticeArc>());
+
+        // Wraps the RNNLM into an FST.  We re-create it for each lattice to
+        // prevent memory usage increasing with time.
+        nnet3::KaldiRnnlmDeterministicFst rnnlm_fst(max_ngram_order,
+                                                    rnn_wordlist,
+                                                    word_symbols_rxfilename,
+                                                    info);
+
+        // Composes lattice with language model.
+        CompactLattice composed_clat;
+        ComposeCompactLatticeDeterministic(clat, &rnnlm_fst, &composed_clat);
+
+        // Determinizes the composed lattice.
+        Lattice composed_lat;
+        ConvertLattice(composed_clat, &composed_lat);
+        Invert(&composed_lat);
+        CompactLattice determinized_clat;
+        DeterminizeLattice(composed_lat, &determinized_clat);
+        fst::ScaleLattice(fst::GraphLatticeScale(lm_scale), &determinized_clat);
+        if (determinized_clat.Start() == fst::kNoStateId) {
+          KALDI_WARN << "Empty lattice for utterance " << key
+                     << " (incompatible LM?)";
+          n_fail++;
+        } else {
+          compact_lattice_writer.Write(key, determinized_clat);
+          n_done++;
+        }
+      } else {
+        // Zero scale, so nothing to do.
+        n_done++;
+        compact_lattice_writer.Write(key, clat);
+      }
+    }
+
+    KALDI_LOG << "Done " << n_done << " lattices, failed for " << n_fail;
+    return (n_done != 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
diff --git a/src/rnnlm/Makefile b/src/rnnlm/Makefile
index ac1ca92f8b3..1e57d2f77ad 100644
--- a/src/rnnlm/Makefile
+++ b/src/rnnlm/Makefile
@@ -10,7 +10,7 @@ TESTFILES = sampler-test sampling-lm-test rnnlm-example-test
 OBJFILES = sampler.o rnnlm-example.o rnnlm-example-utils.o \
            rnnlm-core-training.o rnnlm-embedding-training.o rnnlm-core-compute.o \
            rnnlm-utils.o rnnlm-training.o rnnlm-test-utils.o sampling-lm-estimate.o \
-           sampling-lm.o
+           sampling-lm.o kaldi-rnnlm-decodable-simple-looped.o
 
 LIBNAME = kaldi-rnnlm
 
diff --git a/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc b/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc
new file mode 100644
index 00000000000..ba298e417d3
--- /dev/null
+++ b/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc
@@ -0,0 +1,182 @@
+// rnnlm/kaldi-rnnlm-decodable-simple-looped.cc
+
+// Copyright 2017  Johns Hopkins University (author: Daniel Povey)
+//                 2017  Yiming Wang
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "rnnlm/kaldi-rnnlm-decodable-simple-looped.h"
+#include "nnet3/nnet-utils.h"
+#include "nnet3/nnet-compile-looped.h"
+
+namespace kaldi {
+namespace nnet3 {
+
+
+DecodableRnnlmSimpleLoopedInfo::DecodableRnnlmSimpleLoopedInfo(
+    const DecodableRnnlmSimpleLoopedComputationOptions &opts,
+    const kaldi::nnet3::Nnet &rnnlm,
+    const CuMatrix<BaseFloat> &word_embedding_mat):
+    opts(opts), rnnlm(rnnlm), word_embedding_mat(word_embedding_mat) {
+  Init(opts, rnnlm, word_embedding_mat);
+}
+
+void DecodableRnnlmSimpleLoopedInfo::Init(
+    const DecodableRnnlmSimpleLoopedComputationOptions &opts,
+    const kaldi::nnet3::Nnet &rnnlm,
+    const CuMatrix<BaseFloat> &word_embedding_mat) {
+  opts.Check();
+  KALDI_ASSERT(IsSimpleNnet(rnnlm));
+  int32 left_context, right_context;
+  ComputeSimpleNnetContext(rnnlm, &left_context, &right_context);
+  frames_left_context = opts.extra_left_context_initial + left_context;
+  frames_right_context = right_context;
+  int32 frame_subsampling_factor = 1;
+  frames_per_chunk = GetChunkSize(rnnlm, frame_subsampling_factor,
+                                  opts.frames_per_chunk);
+  KALDI_ASSERT(frames_per_chunk == opts.frames_per_chunk);
+  nnet_output_dim = rnnlm.OutputDim("output");
+  KALDI_ASSERT(nnet_output_dim > 0);
+
+  int32 ivector_period = frames_per_chunk;
+  int32 extra_right_context = 0;
+  int32 num_sequences = 1;  // we're processing one utterance at a time.
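+  // The three requests below describe the first, second and third chunks of
+  // input; CompileLooped() then extrapolates the pattern between the last two
+  // of them into a single 'looped' computation that can be run for an
+  // unbounded number of chunks without re-compilation.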
+  CreateLoopedComputationRequestSimple(rnnlm, frames_per_chunk,
+                                       frame_subsampling_factor,
+                                       ivector_period,
+                                       opts.extra_left_context_initial,
+                                       extra_right_context,
+                                       num_sequences,
+                                       &request1, &request2, &request3);
+
+  CompileLooped(rnnlm, opts.optimize_config, request1, request2,
+                request3, &computation);
+  computation.ComputeCudaIndexes();
+  if (GetVerboseLevel() >= 3) {
+    KALDI_VLOG(3) << "Computation is:";
+    computation.Print(std::cerr, rnnlm);
+  }
+}
+
+DecodableRnnlmSimpleLooped::DecodableRnnlmSimpleLooped(
+    const DecodableRnnlmSimpleLoopedInfo &info) :
+    info_(info),
+    computer_(info_.opts.compute_config, info_.computation,
+              info_.rnnlm, NULL),  // NULL is 'nnet_to_update'
+    // since every time we provide one chunk to the decodable object, the size
+    // of feats_ == frames_per_chunk
+    feats_(info_.frames_per_chunk,
+           info_.word_embedding_mat.NumRows()),  // or Cols()? TODO(hxu)
+    current_log_post_offset_(-1)
+{
+  num_frames_ = feats_.NumRows();
+}
+
+void DecodableRnnlmSimpleLooped::TakeFeatures(
+    const std::vector<int32> &word_indexes) {
+  KALDI_ASSERT(word_indexes.size() == num_frames_);
+  std::vector<std::vector<std::pair<MatrixIndexT, BaseFloat> > >
+      pairs(word_indexes.size());
+  for (int32 i = 0; i < word_indexes.size(); i++) {
+    std::pair<MatrixIndexT, BaseFloat> one_hot_index(word_indexes[i], 1.0);
+    std::vector<std::pair<MatrixIndexT, BaseFloat> > row(1, one_hot_index);
+    pairs[i] = row;
+  }
+  SparseMatrix<BaseFloat> feats_temp(feats_.NumCols(), pairs);
+  feats_.Swap(&feats_temp);
+  // resets the offset so that AdvanceChunk() will be called in GetOutput() and
+  // GetNnetOutputForFrame() after taking new features
+  current_log_post_offset_ = -1;
+}
+
+void DecodableRnnlmSimpleLooped::GetNnetOutputForFrame(
+    int32 frame, VectorBase<BaseFloat> *output) {
+  KALDI_ASSERT(frame >= 0 && frame < feats_.NumRows());
+  if (frame >= current_log_post_offset_ + current_nnet_output_.NumRows())
+    AdvanceChunk();
+  output->CopyFromVec(current_nnet_output_.Row(frame -
+                                               current_log_post_offset_));
+}
+
+BaseFloat DecodableRnnlmSimpleLooped::GetOutput(int32 frame, int32 word_index) {
+  KALDI_ASSERT(frame >= 0 && frame < feats_.NumRows());
+  if (frame >= current_log_post_offset_ + current_nnet_output_.NumRows())
+    AdvanceChunk();
+
+//  int32 embedding_dim = info_.word_embedding_mat.NumCols();
+//  int32 num_words = info_.word_embedding_mat.NumRows();
+
+  const CuMatrix<BaseFloat> &word_embedding_mat = info_.word_embedding_mat;
+
+  CuMatrix<BaseFloat> current_nnet_output_gpu;
+  current_nnet_output_gpu.Swap(&current_nnet_output_);
+  const CuSubVector<BaseFloat> hidden(current_nnet_output_gpu,
+                                      frame - current_log_post_offset_);
+  BaseFloat log_prob =
+      VecVec(hidden, word_embedding_mat.Row(word_index));
+//      output_layer->ComputeLogprobOfWordGivenHistory(hidden, word_index);
+  // swap the pointer back so that this function can be called multiple times
+  // with the same returned value before taking the next new feats
+  current_nnet_output_.Swap(&current_nnet_output_gpu);
+  return log_prob;
+}
+
+void DecodableRnnlmSimpleLooped::AdvanceChunk() {
+  int32 begin_input_frame, end_input_frame;
+  begin_input_frame = -info_.frames_left_context;
+  // note: end is last plus one.
+  end_input_frame = info_.frames_per_chunk + info_.frames_right_context;
+  // currently there is no left/right context and frames_per_chunk == 1
+  KALDI_ASSERT(begin_input_frame == 0 && end_input_frame == 1);
+
+  SparseMatrix<BaseFloat> feats_chunk(end_input_frame - begin_input_frame,
+                                      feats_.NumCols());
+  int32 num_features = feats_.NumRows();
+  for (int32 r = begin_input_frame; r < end_input_frame; r++) {
+    int32 input_frame = r;
+    if (input_frame < 0) input_frame = 0;
+    if (input_frame >= num_features) input_frame = num_features - 1;
+    feats_chunk.SetRow(r - begin_input_frame, feats_.Row(input_frame));
+  }
+
+//  const rnnlm::LmInputComponent* input_layer = info_.lm_nnet.InputLayer();
+//  CuMatrix<BaseFloat> new_input(feats_chunk.NumRows(), input_layer->OutputDim());
+//  input_layer->Propagate(feats_chunk, &new_input);
+
+  CuMatrix<BaseFloat> input_embeddings(1, info_.word_embedding_mat.NumRows(),
+                                       kUndefined);
+  input_embeddings.Row(0).CopyFromVec(
+      info_.word_embedding_mat.Row(feats_chunk.Row(0).Sum()));
+  computer_.AcceptInput("input", &input_embeddings);
+
+  computer_.Run();
+
+  {
+    // Note: here GetOutput() is used instead of GetOutputDestructive(), since
+    // here we have recurrence that goes directly from the output, and the call
+    // to GetOutputDestructive() would cause a crash on the next chunk.
+    CuMatrix<BaseFloat> output(computer_.GetOutput("output"));
+
+    current_nnet_output_.Resize(0, 0);
+    current_nnet_output_.Swap(&output);
+  }
+  KALDI_ASSERT(current_nnet_output_.NumRows() == info_.frames_per_chunk &&
+               current_nnet_output_.NumCols() == info_.nnet_output_dim);
+
+  current_log_post_offset_ = 0;
+}
+
+
+}  // namespace nnet3
+}  // namespace kaldi
diff --git a/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.h b/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.h
new file mode 100644
index 00000000000..40259999a17
--- /dev/null
+++ b/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.h
@@ -0,0 +1,187 @@
+// rnnlm/kaldi-rnnlm-decodable-simple-looped.h
+
+// Copyright 2017  Johns Hopkins University (author: Daniel Povey)
+//                 2017  Yiming Wang
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_RNNLM_DECODABLE_SIMPLE_LOOPED_H_
+#define KALDI_RNNLM_DECODABLE_SIMPLE_LOOPED_H_
+
+#include <vector>
+#include "base/kaldi-common.h"
+#include "gmm/am-diag-gmm.h"
+#include "hmm/transition-model.h"
+#include "itf/decodable-itf.h"
+#include "nnet3/nnet-optimize.h"
+#include "nnet3/nnet-compute.h"
+#include "nnet3/am-nnet-simple.h"
+#include "rnnlm/rnnlm-core-compute.h"
+
+namespace kaldi {
+namespace nnet3 {
+
+// See also nnet-am-decodable-simple.h, which is a decodable object that's
+// based on breaking up the input into fixed chunks.  The decodable object
+// defined here is based on 'looped' computations, which naturally handle
+// infinite left-context (but are only ideal for systems that have only
+// recurrence in the forward direction,
+// i.e. not BLSTMs... because there isn't a natural way to enforce extra right
+// context for each chunk.)
+
+
+// Note: the 'simple' in the name means it applies to networks for which
+// IsSimpleNnet(nnet) would return true.  'looped' means we use looped
+// computations, with a kGotoLabel statement at the end of it.
+struct DecodableRnnlmSimpleLoopedComputationOptions {
+  int32 extra_left_context_initial;
+  int32 frames_per_chunk;
+  bool debug_computation;
+  NnetOptimizeOptions optimize_config;
+  NnetComputeOptions compute_config;
+  DecodableRnnlmSimpleLoopedComputationOptions():
+      extra_left_context_initial(0),
+      frames_per_chunk(1),
+      debug_computation(false) { }
+
+  void Check() const {
+    KALDI_ASSERT(extra_left_context_initial >= 0 && frames_per_chunk > 0);
+  }
+
+  void Register(OptionsItf *opts) {
+    opts->Register("extra-left-context-initial", &extra_left_context_initial,
+                   "Extra left context to use at the first frame of an utterance (note: "
+                   "this will just consist of repeats of the first frame, and should not "
+                   "usually be necessary).");
+    opts->Register("frames-per-chunk", &frames_per_chunk,
+                   "Number of frames in each chunk that is separately evaluated "
+                   "by the neural net.");
+    opts->Register("debug-computation", &debug_computation, "If true, turn on "
+                   "debug for the actual computation (very verbose!)");
+
+    // register the optimization options with the prefix "optimization".
+    ParseOptions optimization_opts("optimization", opts);
+    optimize_config.Register(&optimization_opts);
+
+    // register the compute options with the prefix "computation".
+    ParseOptions compute_opts("computation", opts);
+    compute_config.Register(&compute_opts);
+  }
+};
+
+
+/**
+   When you instantiate class DecodableRnnlmSimpleLooped, you should give it
+   a const reference to this class, that has been previously initialized.
+ */
+class DecodableRnnlmSimpleLoopedInfo {
+ public:
+  DecodableRnnlmSimpleLoopedInfo(
+      const DecodableRnnlmSimpleLoopedComputationOptions &opts,
+      const kaldi::nnet3::Nnet &rnnlm,
+      const CuMatrix<BaseFloat> &word_embedding_mat);
+
+  void Init(const DecodableRnnlmSimpleLoopedComputationOptions &opts,
+            const kaldi::nnet3::Nnet &rnnlm,
+            const CuMatrix<BaseFloat> &word_embedding_mat);
+
+  const DecodableRnnlmSimpleLoopedComputationOptions &opts;
+
+  const kaldi::nnet3::Nnet &rnnlm;
+  const CuMatrix<BaseFloat> &word_embedding_mat;
+
+  // frames_left_context equals the model left context plus the value of the
+  // --extra-left-context-initial option.
+  int32 frames_left_context;
+  // frames_right_context is the same as the right-context of the model.
+  int32 frames_right_context;
+  // The frames_per_chunk equals the number of input frames we need for each
+  // chunk (except for the first chunk).
+  int32 frames_per_chunk;
+
+  // The output dimension of the neural network (not the final output).
+  int32 nnet_output_dim;
+
+  // The 3 computation requests that are used to create the looped
+  // computation are stored in the class, as we need them to work out
+  // exactly which iVectors are needed.
+  ComputationRequest request1, request2, request3;
+
+  // The compiled, 'looped' computation.
+  NnetComputation computation;
+};
+
+/*
+  This class handles the neural net computation; it's mostly accessed
+  via other wrapper classes.
+
+  It accepts just input features. */
+class DecodableRnnlmSimpleLooped {
+ public:
+  /**
+     This constructor takes features as input.
+     Note: it stores references to all arguments to the constructor, so don't
+     delete them till this goes out of scope.
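+     Each DecodableRnnlmSimpleLooped builds its own NnetComputer on top of
+     the looped computation pre-compiled in 'info', so a single
+     DecodableRnnlmSimpleLoopedInfo can be shared by many decodable objects.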
+
+     @param [in] info   This helper class contains all the static pre-computed
+                        information this class needs, and contains a pointer
+                        to the neural net.
+     @param [in] feats  The input feature matrix.
+  */
+  DecodableRnnlmSimpleLooped(const DecodableRnnlmSimpleLoopedInfo &info);
+
+  // returns the number of frames of likelihoods.  The same as feats_.NumRows()
+  inline int32 NumFrames() const { return num_frames_; }
+
+  inline int32 NnetOutputDim() const { return info_.nnet_output_dim; }
+
+  // Gets the nnet's output for a particular frame, with
+  // 0 <= frame < NumFrames().  'output' must be correctly sized (with
+  // dimension NnetOutputDim()).  Note: you're expected to call this, and
+  // GetOutput(), in an order of increasing frames.  If you deviate from this,
+  // one of these calls may crash.
+  void GetNnetOutputForFrame(int32 frame, VectorBase<BaseFloat> *output);
+
+  // Updates feats_ with the new incoming word specified in word_indexes
+  void TakeFeatures(const std::vector<int32> &word_indexes);
+
+  // Gets the output for a particular frame and word_index, with
+  // 0 <= frame < NumFrames().
+  BaseFloat GetOutput(int32 frame, int32 word_index);
+
+ private:
+  // This function does the computation for the next chunk.
+  void AdvanceChunk();
+
+  const DecodableRnnlmSimpleLoopedInfo &info_;
+
+  NnetComputer computer_;
+
+  SparseMatrix<BaseFloat> feats_;
+
+  int32 num_frames_;
+
+  // The current nnet's output that we got from the last time we
+  // ran the computation.
+  Matrix<BaseFloat> current_nnet_output_;
+
+  // The time-offset of the current log-posteriors; equals
+  // -1 when initialized, or 0 once AdvanceChunk() has been called.
+  int32 current_log_post_offset_;
+};
+
+
+}  // namespace nnet3
+}  // namespace kaldi
+
+#endif  // KALDI_RNNLM_DECODABLE_SIMPLE_LOOPED_H_

From 699c9566ce7c70eeb588be4a30418ed1998cb3a5 Mon Sep 17 00:00:00 2001
From: Hainan Xu
Date: Wed, 20 Sep 2017 17:18:17 -0400
Subject: [PATCH 02/23] lattice-rescoring draft finished

---
 src/latbin/Makefile                              |  13 +-
 ...lm.cc => lattice-lmrescore-kaldi-rnnlm.cc}    |  10 +-
 src/rnnlm/Makefile                               |   2 +-
 ...ed.cc => rnnlm-decodable-simple-looped.cc}    |   4 +-
 ...oped.h => rnnlm-decodable-simple-looped.h}    |   0
 src/rnnlm/rnnlm-lattice-rescoring.cc             | 161 ++++++++++++++++++
 src/rnnlm/rnnlm-lattice-rescoring.h              |  88 ++++++++++
 7 files changed, 267 insertions(+), 11 deletions(-)
 rename src/latbin/{lattice-lmrescore-nnet3-rnnlm.cc => lattice-lmrescore-kaldi-rnnlm.cc} (95%)
 rename src/rnnlm/{kaldi-rnnlm-decodable-simple-looped.cc => rnnlm-decodable-simple-looped.cc} (98%)
 rename src/rnnlm/{kaldi-rnnlm-decodable-simple-looped.h => rnnlm-decodable-simple-looped.h} (100%)
 create mode 100644 src/rnnlm/rnnlm-lattice-rescoring.cc
 create mode 100644 src/rnnlm/rnnlm-lattice-rescoring.h

diff --git a/src/latbin/Makefile b/src/latbin/Makefile
index 43210c0d3e0..2a21d084f1e 100644
--- a/src/latbin/Makefile
+++ b/src/latbin/Makefile
@@ -4,6 +4,9 @@ all:
 EXTRA_CXXFLAGS = -Wno-sign-compare
 include ../kaldi.mk
 
+LDFLAGS += $(CUDA_LDFLAGS)
+LDLIBS += $(CUDA_LDLIBS)
+
 BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \
            lattice-lmrescore lattice-scale lattice-union lattice-to-post \
            lattice-determinize lattice-oracle lattice-rmali \
@@ -21,17 +24,19 @@ BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \
            lattice-confidence lattice-determinize-phone-pruned \
            lattice-determinize-phone-pruned-parallel lattice-expand-ngram \
            lattice-lmrescore-const-arpa lattice-lmrescore-rnnlm nbest-to-prons \
-           lattice-arc-post lattice-determinize-non-compact
+           lattice-arc-post lattice-determinize-non-compact lattice-lmrescore-kaldi-rnnlm
 
 OBJFILES =
 
+cuda-compiled.o: ../kaldi.mk
+
 TESTFILES =
 
-ADDLIBS = ../lat/kaldi-lat.a ../lm/kaldi-lm.a ../fstext/kaldi-fstext.a \
-          ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \
-          ../matrix/kaldi-matrix.a \
+ADDLIBS = ../rnnlm/kaldi-rnnlm.a ../lat/kaldi-lat.a ../nnet3/kaldi-nnet3.a ../lm/kaldi-lm.a \
+          ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a \
+          ../util/kaldi-util.a \
+          ../cudamatrix/kaldi-cudamatrix.a ../matrix/kaldi-matrix.a \
           ../base/kaldi-base.a
 
 include ../makefiles/default_rules.mk
diff --git a/src/latbin/lattice-lmrescore-nnet3-rnnlm.cc b/src/latbin/lattice-lmrescore-kaldi-rnnlm.cc
similarity index 95%
rename from src/latbin/lattice-lmrescore-nnet3-rnnlm.cc
rename to src/latbin/lattice-lmrescore-kaldi-rnnlm.cc
index 26754d26629..0ff8789608e 100644
--- a/src/latbin/lattice-lmrescore-nnet3-rnnlm.cc
+++ b/src/latbin/lattice-lmrescore-kaldi-rnnlm.cc
@@ -23,8 +23,9 @@
 #include "fstext/fstext-lib.h"
 #include "lat/kaldi-lattice.h"
 #include "lat/lattice-functions.h"
-#include "rnnlm/kaldi-rnnlm-rescoring.h"
+#include "rnnlm/rnnlm-lattice-rescoring.h"
 #include "util/common-utils.h"
+#include "nnet3/nnet-utils.h"
 
 int main(int argc, char *argv[]) {
   try {
@@ -61,15 +62,16 @@ int main(int argc, char *argv[]) {
       exit(1);
     }
 
-    std::string lats_rspecifier, rnn_wordlist,
+    std::string lats_rspecifier, rnn_wordlist, word_embedding_rxfilename,
         word_symbols_rxfilename, rnnlm_rxfilename, lats_wspecifier;
-    KALDI_ASSERT(po.NumArgs() == 5);
+    KALDI_ASSERT(po.NumArgs() == 6);
 
     rnn_wordlist = po.GetArg(1);
     word_symbols_rxfilename = po.GetArg(2);
     lats_rspecifier = po.GetArg(3);
     rnnlm_rxfilename = po.GetArg(4);
-    lats_wspecifier = po.GetArg(5);
+    word_embedding_rxfilename = po.GetArg(5);
+    lats_wspecifier = po.GetArg(6);
 
     // Reads the language model.
     kaldi::nnet3::Nnet rnnlm;
diff --git a/src/rnnlm/Makefile b/src/rnnlm/Makefile
index 1e57d2f77ad..04228a08201 100644
--- a/src/rnnlm/Makefile
+++ b/src/rnnlm/Makefile
@@ -10,7 +10,7 @@ TESTFILES = sampler-test sampling-lm-test rnnlm-example-test
 OBJFILES = sampler.o rnnlm-example.o rnnlm-example-utils.o \
            rnnlm-core-training.o rnnlm-embedding-training.o rnnlm-core-compute.o \
            rnnlm-utils.o rnnlm-training.o rnnlm-test-utils.o sampling-lm-estimate.o \
-           sampling-lm.o kaldi-rnnlm-decodable-simple-looped.o
+           sampling-lm.o rnnlm-decodable-simple-looped.o rnnlm-lattice-rescoring.o
 
 LIBNAME = kaldi-rnnlm
 
diff --git a/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc b/src/rnnlm/rnnlm-decodable-simple-looped.cc
similarity index 98%
rename from src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc
rename to src/rnnlm/rnnlm-decodable-simple-looped.cc
index ba298e417d3..c6de5a549a3 100644
--- a/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc
+++ b/src/rnnlm/rnnlm-decodable-simple-looped.cc
@@ -18,7 +18,7 @@
 // See the Apache 2 License for the specific language governing permissions and
 // limitations under the License.
 
-#include "rnnlm/kaldi-rnnlm-decodable-simple-looped.h"
+#include "rnnlm/rnnlm-decodable-simple-looped.h"
 #include "nnet3/nnet-utils.h"
 #include "nnet3/nnet-compile-looped.h"
 
@@ -79,7 +79,7 @@ DecodableRnnlmSimpleLooped::DecodableRnnlmSimpleLooped(
     // since every time we provide one chunk to the decodable object, the size
     // of feats_ == frames_per_chunk
     feats_(info_.frames_per_chunk,
-           info_.word_embedding_mat.NumRows()),  // or Cols()? TODO(hxu)
+           info_.word_embedding_mat.NumRows()),  // or Cols()? TODO(hxu), should be vocab size
     current_log_post_offset_(-1)
 {
   num_frames_ = feats_.NumRows();
diff --git a/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.h b/src/rnnlm/rnnlm-decodable-simple-looped.h
similarity index 100%
rename from src/rnnlm/kaldi-rnnlm-decodable-simple-looped.h
rename to src/rnnlm/rnnlm-decodable-simple-looped.h
diff --git a/src/rnnlm/rnnlm-lattice-rescoring.cc b/src/rnnlm/rnnlm-lattice-rescoring.cc
new file mode 100644
index 00000000000..8d7ab58e538
--- /dev/null
+++ b/src/rnnlm/rnnlm-lattice-rescoring.cc
@@ -0,0 +1,161 @@
+// rnnlm/rnnlm-lattice-rescoring.cc
+
+// Copyright 2017  Johns Hopkins University (author: Daniel Povey)
+//                 Yiming Wang
+//                 Hainan Xu
+//
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fstream>
+
+#include "rnnlm/rnnlm-lattice-rescoring.h"
+#include "util/stl-utils.h"
+#include "util/text-utils.h"
+
+namespace kaldi {
+namespace nnet3 {
+
+void KaldiRnnlmDeterministicFst::ReadFstWordSymbolTableAndRnnWordlist(
+    const std::string &rnn_wordlist,
+//    const std::string &rnn_out_wordlist,
+    const std::string &word_symbol_table_rxfilename) {
+  // Reads symbol table.
+  fst::SymbolTable *fst_word_symbols = NULL;
+  if (!(fst_word_symbols =
+        fst::SymbolTable::ReadText(word_symbol_table_rxfilename))) {
+    KALDI_ERR << "Could not read symbol table from file "
+              << word_symbol_table_rxfilename;
+  }
+
+  full_voc_size_ = fst_word_symbols->NumSymbols();
+  fst_label_to_word_.resize(full_voc_size_);
+
+  for (int32 i = 0; i < fst_label_to_word_.size(); ++i) {
+    fst_label_to_word_[i] = fst_word_symbols->Find(i);
+    if (fst_label_to_word_[i] == "") {
+      KALDI_ERR << "Could not find word for integer " << i << " in the word "
+                << "symbol table; mismatched symbol table, or you have "
+                << "discontinuous integers in your symbol table?";
+    }
+  }
+
+//  fst_label_to_rnn_out_label_.resize(fst_word_symbols->NumSymbols(), -1);
+  fst_label_to_rnn_label_.resize(fst_word_symbols->NumSymbols(), -1);
+
+  out_OOS_index_ = 1;
+  {
+    std::ifstream ifile(rnn_wordlist.c_str());
+    int32 id;
+    std::string word;
+    int32 i = 0;
+    while (ifile >> word >> id) {
+      if (word == "") {
+        KALDI_ASSERT(id == out_OOS_index_);
+      }
+      KALDI_ASSERT(i == id);
+      i++;
+      rnn_label_to_word_.push_back(word);
+
+      int fst_label = fst_word_symbols->Find(rnn_label_to_word_[id]);
+      KALDI_ASSERT(fst::SymbolTable::kNoSymbol != fst_label ||
+                   id == out_OOS_index_ || id == 0);
+      if (id != out_OOS_index_ && out_OOS_index_ != 0) {
+        fst_label_to_rnn_label_[fst_label] = id;
+      }
+    }
+  }
+
+  for (int32 i = 0; i < fst_label_to_rnn_label_.size(); i++) {
+    if (fst_label_to_rnn_label_[i] == -1) {
+      fst_label_to_rnn_label_[i] = out_OOS_index_;
+    }
+  }
+}
+
+KaldiRnnlmDeterministicFst::KaldiRnnlmDeterministicFst(int32 max_ngram_order,
+    const std::string &rnn_wordlist,
+    const std::string &word_symbol_table_rxfilename,
+    const DecodableRnnlmSimpleLoopedInfo &info) {
+  max_ngram_order_ = max_ngram_order;
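+  // max_ngram_order_ bounds how much word history is used to distinguish
+  // states of this on-demand FST: per the --max-ngram-order option in
+  // lattice-lmrescore-kaldi-rnnlm.cc, the RNNLM context is limited to that
+  // many words, so different long histories can map to the same state and the
+  // composed lattice stays a manageable size.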
+  ReadFstWordSymbolTableAndRnnWordlist(rnn_wordlist,
+                                       word_symbol_table_rxfilename);
+
+  std::vector