Skip to content

Commit

Permalink
handles memory leaks associated with KenLMBeamScorer
Browse files Browse the repository at this point in the history
  • Loading branch information
joshemorris committed Oct 20, 2017
1 parent ed6bd57 commit 25b425b
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 1 deletion.
5 changes: 5 additions & 0 deletions pytorch_ctc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def __init__(self, labels, lm_path, trie_path, blank_index=0, space_index=28):
self._scorer = ctc._get_kenlm_scorer(labels, len(labels), space_index, blank_index, lm_path.encode(),
trie_path.encode())

# This is a way to make sure the destructor is called for the C++ object
# Frees all the member data items that have allocated memory
def __del__(self):
ctc._free_kenlm_scorer(self._scorer)

def set_lm_weight(self, weight):
if weight is not None:
ctc._set_kenlm_scorer_lm_weight(self._scorer, weight)
Expand Down
5 changes: 5 additions & 0 deletions pytorch_ctc/src/cpu_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ namespace pytorch {
#endif
}

void free_kenlm_scorer(void* kenlm_scorer) {
ctc::KenLMBeamScorer* beam_scorer = static_cast<ctc::KenLMBeamScorer*>(kenlm_scorer);
delete beam_scorer;
}

void set_kenlm_scorer_lm_weight(void *scorer, float weight) {
#ifdef INCLUDE_KENLM
ctc::KenLMBeamScorer *beam_scorer = static_cast<ctc::KenLMBeamScorer *>(scorer);
Expand Down
2 changes: 2 additions & 0 deletions pytorch_ctc/src/cpu_binding.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ typedef enum {
int kenlm_enabled();
void* get_kenlm_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
const char* lm_path, const char* trie_path);
void free_kenlm_scorer(void* kenlm_scorer);

void set_kenlm_scorer_lm_weight(void *scorer, float weight);
void set_kenlm_scorer_wc_weight(void *scorer, float weight);
void set_kenlm_scorer_vwc_weight(void *scorer, float weight);
Expand Down
3 changes: 2 additions & 1 deletion pytorch_ctc/src/ctc_beam_scorer_klm.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ namespace ctc_beam_search {
class KenLMBeamScorer : public BaseBeamScorer<KenLMBeamState> {
public:

virtual ~KenLMBeamScorer() {
~KenLMBeamScorer() {
delete model;
delete trieRoot;
delete labels;
}
KenLMBeamScorer(Labels *labels, const char *kenlm_path, const char *trie_path)
: lm_weight(1.0f),
Expand Down

0 comments on commit 25b425b

Please sign in to comment.