Skip to content

Commit

Permalink
use virtual function to allow any model type to be loaded
Browse files Browse the repository at this point in the history
  • Loading branch information
joshemorris committed Oct 9, 2017
1 parent 40d3883 commit f4a4760
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions pytorch_ctc/src/ctc_beam_scorer_klm.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@ namespace ctc_beam_search {
float delta_score;
std::wstring incomplete_word;
TrieNode *incomplete_word_trie_node;
lm::ngram::ProbingModel::State model_state;
lm::ngram::State model_state;
};
}

using pytorch::ctc::ctc_beam_search::KenLMBeamState;

class KenLMBeamScorer : public BaseBeamScorer<KenLMBeamState> {
public:
typedef lm::ngram::ProbingModel Model;

virtual ~KenLMBeamScorer() {
delete model;
Expand All @@ -59,8 +58,7 @@ namespace ctc_beam_search {
valid_word_count_weight(1.0f) {
lm::ngram::Config config;
config.load_method = util::POPULATE_OR_READ;
model = new Model(kenlm_path, config);

model = lm::ngram::LoadVirtual(kenlm_path, config);
this->labels = labels;

std::ifstream in;
Expand All @@ -76,7 +74,7 @@ namespace ctc_beam_search {
root->delta_score = 0.0f;
root->incomplete_word.clear();
root->incomplete_word_trie_node = trieRoot;
root->model_state = model->BeginSentenceState();
model->BeginSentenceWrite(&root->model_state);
}
// ExpandState is called when expanding a beam to one of its children.
// Called at most once per child beam. In the simplest case, no state
Expand Down Expand Up @@ -126,17 +124,17 @@ namespace ctc_beam_search {
// and retrieving the TopN requested candidates. Called at most once per beam.
void ExpandStateEnd(KenLMBeamState* state) const {
float lm_score_delta = 0.0f;
Model::State out;
lm::ngram::State out;
if (state->incomplete_word.size() > 0) {
lm_score_delta += ScoreIncompleteWord(state->model_state,
state->incomplete_word,
out);
ResetIncompleteWord(state);
state->model_state = out;
}
lm_score_delta += model->FullScore(state->model_state,
model->GetVocabulary().EndSentence(),
out).prob;
lm_score_delta += model->BaseFullScore(&state->model_state,
model->BaseVocabulary().EndSentence(),
&out).prob;
UpdateWithLMScore(state, lm_score_delta);
}
// GetStateExpansionScore should be an inexpensive method to retrieve the
Expand Down Expand Up @@ -174,7 +172,7 @@ namespace ctc_beam_search {
private:
Labels *labels;
TrieNode *trieRoot;
Model *model;
lm::base::Model *model;
float lm_weight;
float word_count_weight;
float valid_word_count_weight;
Expand All @@ -194,19 +192,19 @@ namespace ctc_beam_search {
bool IsOOV(const std::wstring& word) const {
std::string encoded_word;
utf8::utf16to8(word.begin(), word.end(), std::back_inserter(encoded_word));
auto &vocabulary = model->GetVocabulary();
auto &vocabulary = model->BaseVocabulary();
return vocabulary.Index(encoded_word) == vocabulary.NotFound();
}

float ScoreIncompleteWord(const Model::State& model_state,
float ScoreIncompleteWord(const lm::ngram::State& model_state,
const std::wstring& word,
Model::State& out) const {
lm::ngram::State& out) const {
lm::FullScoreReturn full_score_return;
lm::WordIndex vocab;
std::string encoded_word;
utf8::utf16to8(word.begin(), word.end(), std::back_inserter(encoded_word));
vocab = model->GetVocabulary().Index(encoded_word);
full_score_return = model->FullScore(model_state, vocab, out);
vocab = model->BaseVocabulary().Index(encoded_word);
full_score_return = model->BaseFullScore(&model_state, vocab, &out);
return full_score_return.prob;
}

Expand Down

0 comments on commit f4a4760

Please sign in to comment.