nnet3-rnnlm lattice rescoring draft #1906

Merged: 27 commits (Nov 23, 2017)

Changes from 1 commit

Commits (27)
0d839b0  draft (hainan-xv, Sep 18, 2017)
699c956  lattice-rescoring draft finished (hainan-xv, Sep 20, 2017)
ef09b62  lattice-rescoring runnable but buggy (hainan-xv, Sep 22, 2017)
390a1bb  making a PR (hainan-xv, Sep 24, 2017)
dc49709  small changes (hainan-xv, Sep 24, 2017)
8a33e77  include lmrescore_rnn_lat.sh (hainan-xv, Sep 24, 2017)
5965b87  Merge branch 'rnnlm' into rnnlm-rescoring (hainan-xv, Sep 24, 2017)
483450d  some aesthetic changes; not final yet (hainan-xv, Sep 25, 2017)
00912f7  Merge branch 'master' into rnnlm-rescoring (hainan-xv, Sep 25, 2017)
b1167a2  cached version of lattice rescoring; buggy it seems (hainan-xv, Sep 27, 2017)
a52da29  purely aesthetic changes (hainan-xv, Sep 27, 2017)
3bdaa4d  re-written some of the classes (hainan-xv, Sep 28, 2017)
2b08335  very small changes (hainan-xv, Sep 28, 2017)
7cf4af8  fix a typo (hainan-xv, Sep 28, 2017)
8f35242  make RNNLM share the same FST wordlist (hainan-xv, Oct 2, 2017)
705ecc8  fix small issue when running lattice-rescoring with normalize-probs o… (hainan-xv, Oct 2, 2017)
d19ecc1  minor changes (hainan-xv, Oct 6, 2017)
232ef04  fix small stylistic issues in code (hainan-xv, Oct 14, 2017)
bd9936b  fix wrong variable used in scripts/rnnlm/lmrescore_rnnlm_lat.sh (hainan-xv, Oct 14, 2017)
9cc7ba1  add rnnlm softlink in swbd/s5c (hainan-xv, Oct 15, 2017)
267177f  small style changes (hainan-xv, Oct 30, 2017)
87f2f6c  merge with latest upstream (hainan-xv, Nov 3, 2017)
c9bf5e0  move rescoring into rnnlm training scripts (hainan-xv, Nov 7, 2017)
091d4d5  move rescoring into rnnlm training scripts (hainan-xv, Nov 8, 2017)
a192ada  fix small issues mentioned by @danoneata (hainan-xv, Nov 9, 2017)
697f219  change SWBD script to accommodate s5_c; add paper link to RNNLM scrip… (hainan-xv, Nov 20, 2017)
acb5211  fix conflicts (hainan-xv, Nov 20, 2017)
Commit 0d839b0bf66ba09b930328b2fb313343c4d8b18e ("draft")
hainan-xv committed Sep 18, 2017
src/latbin/lattice-lmrescore-nnet3-rnnlm.cc (147 additions, 0 deletions)
@@ -0,0 +1,147 @@
// latbin/lattice-lmrescore-nnet3-rnnlm.cc

// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// Yiming Wang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.


#include "base/kaldi-common.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include "rnnlm/kaldi-rnnlm-rescoring.h"
#include "util/common-utils.h"

int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage =
        "Rescores lattice with rnnlm. The LM will be wrapped into the\n"
        "DeterministicOnDemandFst interface and the rescoring is done by\n"
        "composing with the wrapped LM using a special type of composition\n"
        "algorithm. Determinization will be applied on the composed lattice.\n"
        "\n"
        "Usage: lattice-lmrescore-nnet3-rnnlm [options] <rnnlm-wordlist> \\\n"
        "             <word-symbol-table-rxfilename> <lattice-rspecifier> \\\n"
        "             <rnnlm-rxfilename> <word-embedding-rxfilename> \\\n"
        "             <lattice-wspecifier>\n"
        " e.g.: lattice-lmrescore-nnet3-rnnlm --lm-scale=-1.0 rnnlm_words.txt \\\n"
        "             words.txt ark:in.lats rnnlm word_embedding.mat \\\n"
        "             ark:out.lats\n";

    ParseOptions po(usage);
    int32 max_ngram_order = 3;
    BaseFloat lm_scale = 1.0;

    po.Register("lm-scale", &lm_scale, "Scaling factor for language model "
                "costs; frequently 1.0 or -1.0");
    po.Register("max-ngram-order", &max_ngram_order, "If positive, limit the "
                "rnnlm context to the given number; -1 means we are not going "
                "to limit it.");

    po.Read(argc, argv);

    if (po.NumArgs() != 6) {
      po.PrintUsage();
      exit(1);
    }

    std::string rnn_wordlist = po.GetArg(1),
        word_symbols_rxfilename = po.GetArg(2),
        lats_rspecifier = po.GetArg(3),
        rnnlm_rxfilename = po.GetArg(4),
        word_embedding_rxfilename = po.GetArg(5),
        lats_wspecifier = po.GetArg(6);

    // Reads the language model.
    kaldi::nnet3::Nnet rnnlm;
    ReadKaldiObject(rnnlm_rxfilename, &rnnlm);

    if (!IsSimpleNnet(rnnlm))
      KALDI_ERR << "Input RNNLM in " << rnnlm_rxfilename
                << " is not the type of neural net we were looking for; "
                   "failed IsSimpleNnet().";

    CuMatrix<BaseFloat> word_embedding_mat;
    ReadKaldiObject(word_embedding_rxfilename, &word_embedding_mat);

    const nnet3::DecodableRnnlmSimpleLoopedComputationOptions opts;
    const nnet3::DecodableRnnlmSimpleLoopedInfo info(opts, rnnlm,
                                                     word_embedding_mat);

    // Reads and writes as compact lattice.
    SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier);
    CompactLatticeWriter compact_lattice_writer(lats_wspecifier);

    int32 n_done = 0, n_fail = 0;
    for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) {
      std::string key = compact_lattice_reader.Key();
      CompactLattice clat = compact_lattice_reader.Value();
      compact_lattice_reader.FreeCurrent();

      if (lm_scale != 0.0) {
        // Before composing with the LM FST, we scale the lattice weights
        // by the inverse of "lm_scale".  We'll later scale by "lm_scale".
        // We do it this way so we can determinize and it will give the
        // right effect (taking the "best path" through the LM) regardless
        // of the sign of lm_scale.
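        //
        // Concretely, if lm_scale = s and an arc's existing graph/LM cost is
        // c: after this ScaleLattice it becomes c / s; composition then adds
        // the RNNLM cost r on top; the final ScaleLattice below multiplies
        // by s again, giving s * (c / s + r) = c + s * r, so the original
        // cost is preserved and the RNNLM cost enters with weight s.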
        fst::ScaleLattice(fst::GraphLatticeScale(1.0 / lm_scale), &clat);
        ArcSort(&clat, fst::OLabelCompare<CompactLatticeArc>());

        // Wraps the rnnlm into FST.  We re-create it for each lattice to
        // prevent memory usage increasing with time.
        nnet3::KaldiRnnlmDeterministicFst rnnlm_fst(max_ngram_order,
                                                    rnn_wordlist,
                                                    word_symbols_rxfilename,
                                                    info);

        // Composes lattice with language model.
        CompactLattice composed_clat;
        ComposeCompactLatticeDeterministic(clat, &rnnlm_fst, &composed_clat);

        // Determinizes the composed lattice.
        Lattice composed_lat;
        ConvertLattice(composed_clat, &composed_lat);
        Invert(&composed_lat);
        CompactLattice determinized_clat;
        DeterminizeLattice(composed_lat, &determinized_clat);
        fst::ScaleLattice(fst::GraphLatticeScale(lm_scale),
                          &determinized_clat);
        if (determinized_clat.Start() == fst::kNoStateId) {
          KALDI_WARN << "Empty lattice for utterance " << key
                     << " (incompatible LM?)";
          n_fail++;
        } else {
          compact_lattice_writer.Write(key, determinized_clat);
          n_done++;
        }
      } else {
        // Zero scale so nothing to do.
        n_done++;
        compact_lattice_writer.Write(key, clat);
      }
    }

    KALDI_LOG << "Done " << n_done << " lattices, failed for " << n_fail;
    return (n_done != 0 ? 0 : 1);
  } catch(const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}
src/rnnlm/Makefile (1 addition, 1 deletion)
@@ -10,7 +10,7 @@ TESTFILES = sampler-test sampling-lm-test rnnlm-example-test
 OBJFILES = sampler.o rnnlm-example.o rnnlm-example-utils.o \
            rnnlm-core-training.o rnnlm-embedding-training.o rnnlm-core-compute.o \
            rnnlm-utils.o rnnlm-training.o rnnlm-test-utils.o sampling-lm-estimate.o \
-           sampling-lm.o
+           sampling-lm.o kaldi-rnnlm-decodable-simple-looped.o

 LIBNAME = kaldi-rnnlm
src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc (182 additions, 0 deletions)
@@ -0,0 +1,182 @@
// rnnlm/kaldi-rnnlm-decodable-simple-looped.cc

// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// 2017 Yiming Wang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "rnnlm/kaldi-rnnlm-decodable-simple-looped.h"
#include "nnet3/nnet-utils.h"
#include "nnet3/nnet-compile-looped.h"

namespace kaldi {
namespace nnet3 {


DecodableRnnlmSimpleLoopedInfo::DecodableRnnlmSimpleLoopedInfo(
    const DecodableRnnlmSimpleLoopedComputationOptions &opts,
    const kaldi::nnet3::Nnet &rnnlm,
    const CuMatrix<BaseFloat> &word_embedding_mat):
    opts(opts), rnnlm(rnnlm), word_embedding_mat(word_embedding_mat) {
  Init(opts, rnnlm, word_embedding_mat);
}

void DecodableRnnlmSimpleLoopedInfo::Init(
    const DecodableRnnlmSimpleLoopedComputationOptions &opts,
    const kaldi::nnet3::Nnet &rnnlm,
    const CuMatrix<BaseFloat> &word_embedding_mat) {
  opts.Check();
  KALDI_ASSERT(IsSimpleNnet(rnnlm));
  int32 left_context, right_context;
  ComputeSimpleNnetContext(rnnlm, &left_context, &right_context);
  frames_left_context = opts.extra_left_context_initial + left_context;
  frames_right_context = right_context;
  int32 frame_subsampling_factor = 1;
  frames_per_chunk = GetChunkSize(rnnlm, frame_subsampling_factor,
                                  opts.frames_per_chunk);
  KALDI_ASSERT(frames_per_chunk == opts.frames_per_chunk);
  nnet_output_dim = rnnlm.OutputDim("output");
  KALDI_ASSERT(nnet_output_dim > 0);

  int32 ivector_period = frames_per_chunk;
  int32 extra_right_context = 0;
  int32 num_sequences = 1;  // we're processing one utterance at a time.
  CreateLoopedComputationRequestSimple(rnnlm, frames_per_chunk,
                                       frame_subsampling_factor,
                                       ivector_period,
                                       opts.extra_left_context_initial,
                                       extra_right_context,
                                       num_sequences,
                                       &request1, &request2, &request3);
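  // request1, request2 and request3 are the computation requests for the
  // first, second and third chunks; CompileLooped() below uses them to
  // detect where the computation becomes repeating, so the compiled
  // computation can be run for an unbounded number of chunks without
  // recompilation.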

  CompileLooped(rnnlm, opts.optimize_config, request1, request2,
                request3, &computation);
  computation.ComputeCudaIndexes();
  if (GetVerboseLevel() >= 3) {
    KALDI_VLOG(3) << "Computation is:";
    computation.Print(std::cerr, rnnlm);
  }
}

DecodableRnnlmSimpleLooped::DecodableRnnlmSimpleLooped(
    const DecodableRnnlmSimpleLoopedInfo &info) :
    info_(info),
    computer_(info_.opts.compute_config, info_.computation,
              info_.rnnlm, NULL),  // NULL is 'nnet_to_update'
    // since every time we provide one chunk to the decodable object, the
    // size of feats_ == frames_per_chunk
    feats_(info_.frames_per_chunk,
           info_.word_embedding_mat.NumRows()),  // or Cols()? TODO(hxu)
    current_log_post_offset_(-1)
{
  num_frames_ = feats_.NumRows();
}

void DecodableRnnlmSimpleLooped::TakeFeatures(
    const std::vector<int32> &word_indexes) {
  KALDI_ASSERT(word_indexes.size() == num_frames_);
  std::vector<std::vector<std::pair<MatrixIndexT, BaseFloat> > >
      pairs(word_indexes.size());
  for (int32 i = 0; i < word_indexes.size(); i++) {
    std::pair<MatrixIndexT, BaseFloat> one_hot_index(word_indexes[i], 1.0);
    std::vector<std::pair<MatrixIndexT, BaseFloat> > row(1, one_hot_index);
    pairs[i] = row;
  }
  SparseMatrix<BaseFloat> feats_temp(feats_.NumCols(), pairs);
  feats_.Swap(&feats_temp);
  // resets offset so that AdvanceChunk() will be called in GetOutput() and
  // GetNnetOutputForFrame() after taking new features
  current_log_post_offset_ = -1;
}
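
// Illustration (hypothetical word id): after TakeFeatures({42}), feats_ is a
// 1 x num-words SparseMatrix whose single row contains the one pair
// (42, 1.0), i.e. a one-hot encoding of the word; AdvanceChunk() later looks
// this index up in the word embedding matrix.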

void DecodableRnnlmSimpleLooped::GetNnetOutputForFrame(
    int32 frame, VectorBase<BaseFloat> *output) {
  KALDI_ASSERT(frame >= 0 && frame < feats_.NumRows());
  if (frame >= current_log_post_offset_ + current_nnet_output_.NumRows())
    AdvanceChunk();
  output->CopyFromVec(current_nnet_output_.Row(frame -
                                               current_log_post_offset_));
}

BaseFloat DecodableRnnlmSimpleLooped::GetOutput(int32 frame, int32 word_index) {
  KALDI_ASSERT(frame >= 0 && frame < feats_.NumRows());
  if (frame >= current_log_post_offset_ + current_nnet_output_.NumRows())
    AdvanceChunk();

  const CuMatrix<BaseFloat> &word_embedding_mat = info_.word_embedding_mat;

  CuMatrix<BaseFloat> current_nnet_output_gpu;
  current_nnet_output_gpu.Swap(&current_nnet_output_);
  const CuSubVector<BaseFloat> hidden(current_nnet_output_gpu,
                                      frame - current_log_post_offset_);
  BaseFloat log_prob =
      VecVec(hidden, word_embedding_mat.Row(word_index));
  // swap the pointer back so that this function can be called multiple times
  // with the same returned value before taking the next new feats
  current_nnet_output_.Swap(&current_nnet_output_gpu);
  return log_prob;
}
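
// Note: GetOutput(frame, w) returns the dot product between the network's
// output vector h for that frame and word w's embedding row e_w.  This is
// only a proper log-probability to the extent that the model's outputs are
// normalized; in this setup normalization is assumed to be handled at RNNLM
// training time (cf. the normalize-probs fix in the commit history) rather
// than by an explicit softmax here.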

void DecodableRnnlmSimpleLooped::AdvanceChunk() {
  int32 begin_input_frame, end_input_frame;
  begin_input_frame = -info_.frames_left_context;
  // note: end is last plus one.
  end_input_frame = info_.frames_per_chunk + info_.frames_right_context;
  // currently there is no left/right context and frames_per_chunk == 1
  KALDI_ASSERT(begin_input_frame == 0 && end_input_frame == 1);

  SparseMatrix<BaseFloat> feats_chunk(end_input_frame - begin_input_frame,
                                      feats_.NumCols());
  int32 num_features = feats_.NumRows();
  for (int32 r = begin_input_frame; r < end_input_frame; r++) {
    int32 input_frame = r;
    if (input_frame < 0) input_frame = 0;
    if (input_frame >= num_features) input_frame = num_features - 1;
    feats_chunk.SetRow(r - begin_input_frame, feats_.Row(input_frame));
  }

  // feats_chunk.Row(0) is a one-hot row whose single nonzero element sits at
  // the index of the current word; look that index up in the embedding
  // matrix to get the input embedding for this chunk.
  int32 word_index = feats_chunk.Row(0).GetElement(0).first;
  CuMatrix<BaseFloat> input_embeddings(1, info_.word_embedding_mat.NumCols(),
                                       kUndefined);
  input_embeddings.Row(0).CopyFromVec(
      info_.word_embedding_mat.Row(word_index));
  computer_.AcceptInput("input", &input_embeddings);

  computer_.Run();

  {
    // Note: here GetOutput() is used instead of GetOutputDestructive(), since
    // here we have recurrence that goes directly from the output, and the call
    // to GetOutputDestructive() would cause a crash on the next chunk.
    CuMatrix<BaseFloat> output(computer_.GetOutput("output"));

    current_nnet_output_.Resize(0, 0);
    current_nnet_output_.Swap(&output);
  }
  KALDI_ASSERT(current_nnet_output_.NumRows() == info_.frames_per_chunk &&
               current_nnet_output_.NumCols() == info_.nnet_output_dim);

  current_log_post_offset_ = 0;
}
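
// Each call to AdvanceChunk() runs the looped computation for one more
// chunk; computer_ carries the recurrent state across calls, so successive
// TakeFeatures()/GetOutput() calls see a running history.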


} // namespace nnet3
} // namespace kaldi
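
A minimal sketch of how this decodable object is meant to be driven (the
class and method names are taken from this diff; the word ids and the
surrounding setup are hypothetical):

// Assumes rnnlm and word_embedding_mat have been read as in
// lattice-lmrescore-nnet3-rnnlm.cc above.
const nnet3::DecodableRnnlmSimpleLoopedComputationOptions opts;
const nnet3::DecodableRnnlmSimpleLoopedInfo info(opts, rnnlm,
                                                 word_embedding_mat);
nnet3::DecodableRnnlmSimpleLooped decodable(info);

// Feed one word of history (frames_per_chunk == 1 in this draft), then ask
// for the score of a candidate next word on frame 0.
std::vector<int32> history(1, 42);                // 42: hypothetical word id
decodable.TakeFeatures(history);
BaseFloat log_prob = decodable.GetOutput(0, 87);  // 87: hypothetical word id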