Skip to content

Commit

Permalink
Fix test and seg fault
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanleary committed Jun 26, 2017
1 parent 90a1463 commit bd6b3cd
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
9 changes: 6 additions & 3 deletions pytorch_ctc/src/ctc_beam_entry.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,16 @@ struct BeamEntry {
return;
}
children = std::vector<BeamEntry>(L-1);
for (int ci = 0; ci < L; ++ci) {
if (ci == blank) continue;
for (int ci=0,cl=0; ci < L; ++ci) {
if (ci == blank) {
continue;
}
// The current object cannot be copied, and should not be moved.
// Otherwise the child's parent will become invalid.
auto& c = children[ci];
auto& c = children[cl];
c.parent = this;
c.label = ci;
++cl;
}
}
inline std::vector<BeamEntry>* Children() {
Expand Down
6 changes: 3 additions & 3 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_simple_decode_different_blank_idx(self):
aa = torch.FloatTensor(np.array([[[0.0, 1.0]], [[0.0, 1.0]], [[1.0, 0.0]], [[0.0, 1.0]], [[0.0, 1.0]]], dtype=np.float32)).log()
seq_len = torch.IntTensor(np.array([5], dtype=np.int32))

labels="A_"
labels="_A"
scorer = pytorch_ctc.Scorer()
decoder_merge = pytorch_ctc.CTCBeamDecoder(scorer, labels, blank_index=0, space_index=-1, top_paths=1, beam_width=1, merge_repeated=True)
decoder_nomerge = pytorch_ctc.CTCBeamDecoder(scorer, labels, blank_index=0, space_index=-1, top_paths=1, beam_width=1, merge_repeated=False)
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_ctc_decoder_beam_search(self):
self.assertEqual(decode_result.numpy()[1,0,:decode_len[1][0]].tolist(), [0, 1, 0])
np.testing.assert_almost_equal(scores.numpy(), np.array([[-0.584855], [-0.389139]]), 5)

def test_ctc_decoder_beam_search_different_blank_ids(self):
def test_ctc_decoder_beam_search_different_blank_idx(self):
depth = 6
seq_len_0 = 5
input_prob_matrix_0 = np.asarray(
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_ctc_decoder_beam_search_different_blank_ids(self):
th_input = torch.from_numpy(inputs)
th_seq_len = torch.IntTensor(seq_lens)

labels="ABCDE_"
labels="_ABCDE"
scorer = pytorch_ctc.Scorer()
decoder = pytorch_ctc.CTCBeamDecoder(scorer, labels, blank_index=0, space_index=-1, top_paths=2, beam_width=2, merge_repeated=False)

Expand Down

0 comments on commit bd6b3cd

Please sign in to comment.