Skip to content

Commit

Permalink
Adding sgmm2 code (not all fully finished, but compiles).
Browse files Browse the repository at this point in the history
git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@1025 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
  • Loading branch information
danpovey committed Jun 12, 2012
1 parent 0279c08 commit b46e414
Show file tree
Hide file tree
Showing 47 changed files with 8,738 additions and 231 deletions.
1 change: 1 addition & 0 deletions src/TODO
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

=====
dan's TODO:
address roundoff issues RE lattice generation?
wali_to_ctm.sh assumes that conv-side is spk-{A,B}, not always true.
eventually remove warning RE thread-test
fix SGMM w/ resizing spk vecs.
Expand Down
3 changes: 2 additions & 1 deletion src/decoder/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ TESTFILES =

OBJFILES = decodable-am-diag-gmm.o training-graph-compiler.o decodable-am-sgmm.o \
decodable-am-tied-diag-gmm.o decodable-am-tied-full-gmm.o \
lattice-simple-decoder.o lattice-faster-decoder.o faster-decoder.o
lattice-simple-decoder.o lattice-faster-decoder.o faster-decoder.o \
decodable-am-sgmm2.o

LIBFILE = kaldi-decoder.a

Expand Down
2 changes: 1 addition & 1 deletion src/decoder/decodable-am-diag-gmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ class DecodableAmDiagGmmUnmapped : public DecodableInterface {
return (frame == NumFrames() - 1);
}

void ResetLogLikeCache();
protected:
void ResetLogLikeCache();
virtual BaseFloat LogLikelihoodZeroBased(int32 frame, int32 state_index);

const AmDiagGmm &acoustic_model_;
Expand Down
41 changes: 0 additions & 41 deletions src/decoder/decodable-am-sgmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,45 +67,4 @@ void DecodableAmSgmm::ResetLogLikeCache() {
for (; it != end; ++it) { it->hit_time = -1; }
}

BaseFloat DecodableAmSgmmFmllr::LogLikelihoodZeroBased(int32 frame, int32 state) {
KALDI_ASSERT(frame >= 0 && frame < NumFrames());
KALDI_ASSERT(state >= 0 && state < NumIndices());

if (log_like_cache_[state].hit_time == frame) {
return log_like_cache_[state].log_like; // return cached value, if found
}

const VectorBase<BaseFloat> &data = feature_matrix_.Row(frame);
// check if everything is in order
if (acoustic_model_.FeatureDim() != data.Dim()) {
KALDI_ERR << "Dim mismatch: data dim = " << data.Dim()
<< "vs. model dim = " << acoustic_model_.FeatureDim();
}

if (frame != previous_frame_) { // Per-frame precomputation for SGMM.
int32 dim = acoustic_model_.FeatureDim();
Vector<BaseFloat> extended_data(dim+1, kUndefined);
extended_data.Range(0, dim).CopyFromVec(data);
extended_data(dim) = 1.0;
xformed_feat_.AddMatVec(1.0, fmllr_mat_, kNoTrans, extended_data, 0.0);
if (gselect_all_.empty())
acoustic_model_.GaussianSelection(sgmm_config_, xformed_feat_, &gselect_);
else {
KALDI_ASSERT(frame < gselect_all_.size());
gselect_ = gselect_all_[frame];
}
acoustic_model_.ComputePerFrameVars(xformed_feat_, gselect_, spk_, logdet_,
&per_frame_vars_);
previous_frame_ = frame;
}

BaseFloat loglike = acoustic_model_.LogLikelihood(per_frame_vars_, state,
log_prune_);
if (KALDI_ISNAN(loglike) || KALDI_ISINF(loglike))
KALDI_ERR << "Invalid answer (overflow or invalid variances/features?)";
log_like_cache_[state].log_like = loglike;
log_like_cache_[state].hit_time = frame;
return loglike;
}

} // namespace kaldi
37 changes: 1 addition & 36 deletions src/decoder/decodable-am-sgmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ class DecodableAmSgmm : public DecodableInterface {
return (frame == NumFrames() - 1);
}

void ResetLogLikeCache();

protected:
void ResetLogLikeCache();
virtual BaseFloat LogLikelihoodZeroBased(int32 frame, int32 pdf_id);

const AmSgmm &acoustic_model_;
Expand Down Expand Up @@ -112,40 +111,6 @@ class DecodableAmSgmmScaled : public DecodableAmSgmm {
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableAmSgmmScaled);
};

class DecodableAmSgmmFmllr : public DecodableAmSgmm {
public:
DecodableAmSgmmFmllr(const SgmmGselectConfig &opts,
const AmSgmm &am,
const SgmmPerSpkDerivedVars &spk, // may be empty
const TransitionModel &tm,
const Matrix<BaseFloat> &feats,
const std::vector<std::vector<int32> > &gselect_all,
// gselect_all may be empty
BaseFloat log_prune,
BaseFloat scale,
const MatrixBase<BaseFloat> &fmllr)
: DecodableAmSgmm(opts, am, spk, tm, feats, gselect_all, log_prune),
fmllr_mat_(fmllr), xformed_feat_(am.FeatureDim()), scale_(scale) {
int32 dim = am.FeatureDim();
logdet_ = fmllr_mat_.Range(0, dim, 0, dim).LogDet();
}

// Note, frames are numbered from zero but transition-ids from one.
virtual BaseFloat LogLikelihood(int32 frame, int32 tid) {
return LogLikelihoodZeroBased(frame, trans_model_.TransitionIdToPdf(tid))
* scale_;
}

protected:
virtual BaseFloat LogLikelihoodZeroBased(int32 frame, int32 pdf_id);

private:
Matrix<BaseFloat> fmllr_mat_;
Vector<BaseFloat> xformed_feat_;
BaseFloat scale_, logdet_;
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableAmSgmmFmllr);
};


} // namespace kaldi

Expand Down
43 changes: 43 additions & 0 deletions src/decoder/decodable-am-sgmm2.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// decoder/decodable-am-sgmm2.cc

// Copyright 2009-2012 Saarland University; Lukas Burget;
// Johns Hopkins University (author: Daniel Povey)

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <vector>
using std::vector;

#include "decoder/decodable-am-sgmm2.h"

namespace kaldi {

BaseFloat DecodableAmSgmm2::LogLikelihoodForPdf(int32 frame, int32 pdf_id) {
if (frame != cur_frame_) {
cur_frame_ = frame;
sgmm_cache_.NextFrame(); // it has a frame-index internally but it doesn't
// have to match up with our index here, it just needs to be unique.


SubVector<BaseFloat> data(feature_matrix_, frame);

sgmm_.ComputePerFrameVars(data, gselect_[frame], *spk_,
&per_frame_vars_);
}
return sgmm_.LogLikelihood(per_frame_vars_, pdf_id, &sgmm_cache_, spk_,
log_prune_);
}


} // namespace kaldi
103 changes: 103 additions & 0 deletions src/decoder/decodable-am-sgmm2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// decoder/decodable-am-sgmm2.h

// Copyright 2009-2012 Saarland University Microsoft Corporation
// Lukas Burget Johns Hopkins University (author: Daniel Povey)

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_DECODER_DECODABLE_AM_SGMM2_H_
#define KALDI_DECODER_DECODABLE_AM_SGMM2_H_

#include <vector>

#include "base/kaldi-common.h"
#include "sgmm2/am-sgmm.h"
#include "hmm/transition-model.h"
#include "itf/decodable-itf.h"

namespace kaldi {

class DecodableAmSgmm2 : public DecodableInterface {
public:
DecodableAmSgmm2(const AmSgmm2 &sgmm,
const TransitionModel &tm,
const Matrix<BaseFloat> &feats,
const std::vector<std::vector<int32> > &gselect,
BaseFloat log_prune,
Sgmm2PerSpkDerivedVars *spk):
sgmm_(sgmm), spk_(spk),
trans_model_(tm), feature_matrix_(feats),
gselect_(gselect), log_prune_(log_prune), cur_frame_(-1),
sgmm_cache_(sgmm.NumGroups(), sgmm.NumPdfs()) {
KALDI_ASSERT(gselect.size() == static_cast<size_t>(feats.NumRows()));
}

// Note, frames are numbered from zero, but transition indices are 1-based!
// This is for compatibility with OpenFST.
virtual BaseFloat LogLikelihood(int32 frame, int32 tid) {
return LogLikelihoodForPdf(frame, trans_model_.TransitionIdToPdf(tid));
}
int32 NumFrames() { return feature_matrix_.NumRows(); }
virtual int32 NumIndices() { return trans_model_.NumTransitionIds(); }

virtual bool IsLastFrame(int32 frame) {
KALDI_ASSERT(frame < NumFrames());
return (frame == NumFrames() - 1);
}

protected:
virtual BaseFloat LogLikelihoodForPdf(int32 frame, int32 pdf_id);

const AmSgmm2 &sgmm_;
Sgmm2PerSpkDerivedVars *spk_;
const TransitionModel &trans_model_; ///< for tid to pdf mapping
const Matrix<BaseFloat> &feature_matrix_;
const std::vector<std::vector<int32> > gselect_;

BaseFloat log_prune_;

int32 cur_frame_;
Sgmm2PerFrameDerivedVars per_frame_vars_;
Sgmm2LikelihoodCache sgmm_cache_;

private:
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableAmSgmm2);
};

class DecodableAmSgmm2Scaled : public DecodableAmSgmm2 {
public:
DecodableAmSgmm2Scaled(const AmSgmm2 &sgmm,
const TransitionModel &tm,
const Matrix<BaseFloat> &feats,
const std::vector<std::vector<int32> > &gselect,
BaseFloat log_prune,
BaseFloat scale,
Sgmm2PerSpkDerivedVars *spk)
: DecodableAmSgmm2(sgmm, tm, feats, gselect, log_prune, spk),
scale_(scale) {}

// Note, frames are numbered from zero but transition-ids from one.
virtual BaseFloat LogLikelihood(int32 frame, int32 tid) {
return LogLikelihoodForPdf(frame, trans_model_.TransitionIdToPdf(tid))
* scale_;
}
private:
BaseFloat scale_;
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableAmSgmm2Scaled);
};


} // namespace kaldi

#endif // KALDI_DECODER_DECODABLE_AM_SGMM_H_
4 changes: 3 additions & 1 deletion src/decoder/lattice-faster-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ bool LatticeFasterDecoder::GetRawLattice(fst::MutableFst<LatticeArc> *ofst) cons
if (f == 0 && ofst->NumStates() > 0)
ofst->SetStart(ofst->NumStates()-1);
}
KALDI_VLOG(3) << "init:" << num_toks_/2 + 3 << " buckets:" << tok_map.bucket_count() << " load:" << tok_map.load_factor() << " max:" << tok_map.max_load_factor();
KALDI_VLOG(3) << "init:" << num_toks_/2 + 3 << " buckets:"
<< tok_map.bucket_count() << " load:" << tok_map.load_factor()
<< " max:" << tok_map.max_load_factor();
// Now create all arcs.
StateId cur_state = 0; // we rely on the fact that we numbered these
// consecutively (AddState() returns the numbers in order..)
Expand Down
70 changes: 4 additions & 66 deletions src/gmm/am-diag-gmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,77 +100,15 @@ int32 AmDiagGmm::ComputeGconsts() {
return num_bad;
}

struct CountStats {
CountStats(int32 p, int32 n, BaseFloat occ)
: pdf_index(p), num_components(n), occupancy(occ) {}
int32 pdf_index;
int32 num_components;
BaseFloat occupancy;
bool operator < (const CountStats &other) const {
return occupancy/(num_components+1.0e-10) <
other.occupancy/(other.num_components+1.0e-10);
}
};

void AmDiagGmm::ComputeTargetNumPdfs(const Vector<BaseFloat> &state_occs,
int32 target_components,
BaseFloat power,
BaseFloat min_count,
std::vector<int32> *targets) const {
KALDI_ASSERT(static_cast<int32>(state_occs.Dim()) == NumPdfs());

std::priority_queue<CountStats> split_queue;
int32 current_components = 0;
for (int32 pdf_index = 0, num_pdf = NumPdfs(); pdf_index < num_pdf;
++pdf_index) {
BaseFloat occ = pow(state_occs(pdf_index), power);
// initialize with one Gaussian per PDF, to put a floor
// of 1 on the #Gauss
split_queue.push(CountStats(pdf_index, 1, occ));
current_components += densities_[pdf_index]->NumGauss();
}
KALDI_ASSERT(current_components == NumGauss());

for (int32 num_gauss = NumPdfs(); // since we initialized with 1 per PDF.
num_gauss < target_components;
++num_gauss) {
CountStats state_to_split = split_queue.top();
if (state_to_split.occupancy == 0) {
KALDI_WARN << "Could not split up to " << target_components
<< " due to min-count = " << min_count
<< " (or no counts at all)\n";
break;
}
split_queue.pop();
BaseFloat orig_occ = state_occs(state_to_split.pdf_index);
if ((state_to_split.num_components+1) * min_count >= orig_occ) {
state_to_split.occupancy = 0; // min-count active -> disallow splitting
// this state any more by setting occupancy = 0.
} else {
state_to_split.num_components++;
}
split_queue.push(state_to_split);
}

targets->resize(NumPdfs());

current_components = 0;
while (!split_queue.empty()) {
int32 pdf_index = split_queue.top().pdf_index;
int32 pdf_tgt_comp = split_queue.top().num_components;
(*targets)[pdf_index] = pdf_tgt_comp;
split_queue.pop();
}
}

void AmDiagGmm::SplitByCount(const Vector<BaseFloat> &state_occs,
int32 target_components,
float perturb_factor, BaseFloat power,
BaseFloat min_count) {
int32 gauss_at_start = NumGauss();
std::vector<int32> targets;
ComputeTargetNumPdfs(state_occs, target_components, power,
min_count, &targets);
GetSplitTargets(state_occs, target_components, power,
min_count, &targets);

for (int32 i = 0; i < NumPdfs(); i++) {
if (densities_[i]->NumGauss() < targets[i])
Expand All @@ -192,8 +130,8 @@ void AmDiagGmm::MergeByCount(const Vector<BaseFloat> &state_occs,
BaseFloat min_count) {
int32 gauss_at_start = NumGauss();
std::vector<int32> targets;
ComputeTargetNumPdfs(state_occs, target_components,
power, min_count, &targets);
GetSplitTargets(state_occs, target_components,
power, min_count, &targets);

for (int32 i = 0; i < NumPdfs(); i++) {
if (targets[i] == 0) targets[i] = 1; // can't merge below 1.
Expand Down
7 changes: 0 additions & 7 deletions src/gmm/am-diag-gmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,6 @@ class AmDiagGmm {

void RemovePdf(int32 pdf_index);

// used in SplitByCount and MergeByCount.
void ComputeTargetNumPdfs(const Vector<BaseFloat> &state_occs,
int32 target_components,
BaseFloat power,
BaseFloat min_count,
std::vector<int32> *targets) const;

KALDI_DISALLOW_COPY_AND_ASSIGN(AmDiagGmm);
};

Expand Down
Loading

0 comments on commit b46e414

Please sign in to comment.