Skip to content

Commit

Permalink
Changes to support offset information
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanNaren committed Oct 25, 2017
1 parent 7d6b56d commit b45e950
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 37 deletions.
6 changes: 4 additions & 2 deletions pytorch_ctc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ def decode(self, probs, seq_len=None):
output = torch.IntTensor(self._top_paths, batch_size, max_seq_len)
scores = torch.FloatTensor(self._top_paths, batch_size)
out_seq_len = torch.IntTensor(self._top_paths, batch_size)
alignments = torch.IntTensor(self._top_paths, batch_size, max_seq_len)

result = ctc._ctc_beam_decode(self._decoder, self._decoder_type, probs, seq_len, output, scores, out_seq_len)
result = ctc._ctc_beam_decode(self._decoder, self._decoder_type, probs, seq_len, output, scores, out_seq_len,
alignments)

return output, scores, out_seq_len
return output, scores, out_seq_len, alignments


class BaseScorer(object):
Expand Down
12 changes: 9 additions & 3 deletions pytorch_ctc/src/cpu_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ namespace pytorch {
}

int ctc_beam_decode(void *void_decoder, DecodeType type, THFloatTensor *th_probs, THIntTensor *th_seq_len, THIntTensor *th_output,
THFloatTensor *th_scores, THIntTensor *th_out_len)
THFloatTensor *th_scores, THIntTensor *th_out_len, THIntTensor *th_alignments)
{
const int64_t max_time = THFloatTensor_size(th_probs, 0);
const int64_t batch_size = THFloatTensor_size(th_probs, 1);
Expand Down Expand Up @@ -167,6 +167,10 @@ namespace pytorch {
for (ctc::CTCDecoder::Output& output : outputs) {
output.resize(batch_size);
}
std::vector<ctc::CTCDecoder::Output> alignments(top_paths);
for (ctc::CTCDecoder::Output& alignment : alignments) {
alignment.resize(batch_size);
}
float score[batch_size][top_paths];
memset(score, 0.0, batch_size*top_paths*sizeof(int));
Eigen::Map<Eigen::MatrixXf> *scores;
Expand All @@ -177,7 +181,7 @@ namespace pytorch {
{
ctc::CTCBeamSearchDecoder<> *decoder = static_cast<ctc::CTCBeamSearchDecoder<> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores);
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments);
if (!stat.ok()) {
return 0;
}
Expand All @@ -188,7 +192,7 @@ namespace pytorch {
{
ctc::CTCBeamSearchDecoder<KenLMBeamState> *decoder = static_cast<ctc::CTCBeamSearchDecoder<KenLMBeamState> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores);
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments);
if (!stat.ok()) {
return 0;
}
Expand All @@ -203,13 +207,15 @@ namespace pytorch {
int64_t offset = 0;
for (int b=0; b < batch_size; ++b) {
auto& p_batch = outputs[p][b];
auto& alignment_batch = alignments[p][b];
int64_t num_decoded = p_batch.size();

max_decoded = std::max(max_decoded, num_decoded);
THIntTensor_set2d(th_out_len, p, b, num_decoded);
for (int64_t t=0; t < num_decoded; ++t) {
// TODO: this could be more efficient (significant pointer arithmetic every time currently)
THIntTensor_set3d(th_output, p, b, t, p_batch[t]);
THIntTensor_set3d(th_alignments, p, b, t, alignment_batch[t]);
THFloatTensor_set2d(th_scores, p, b, (*scores)(b, p));
}
}
Expand Down
2 changes: 1 addition & 1 deletion pytorch_ctc/src/cpu_binding.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void* get_ctc_beam_decoder(int num_classes, int top_paths, int beam_width, int b
/* run decoding */
int ctc_beam_decode(void *decoder, DecodeType type,
THFloatTensor *probs, THIntTensor *seq_len, THIntTensor *output,
THFloatTensor *scores, THIntTensor *th_out_len);
THFloatTensor *scores, THIntTensor *th_out_len, THIntTensor *th_alignments);


/* utilities */
Expand Down
26 changes: 22 additions & 4 deletions pytorch_ctc/src/ctc_beam_entry.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,18 @@ struct BeamProbability {
template <class CTCBeamState = EmptyBeamState>
struct BeamEntry {
// Default constructor does not create a vector of children.
BeamEntry() : parent(nullptr), label(-1) {}
BeamEntry() : parent(nullptr), label(-1), time_step(-1) {}
// Constructor giving parent, label, and number of children does
// create a vector of children. The object pointed to by p
// cannot be copied and should not be moved, otherwise parent will
// become invalid.
BeamEntry(BeamEntry* p, int l, int L, int blank) : parent(p), label(l) {
PopulateChildren(L, blank);
BeamEntry(BeamEntry* p, int l, int L, int blank, int t)
: parent(p), label(l), time_step(t) {
PopulateChildren(L, blank, t);
}
inline bool Active() const { return newp.total != kLogZero; }
inline bool HasChildren() const { return !children.empty(); }
void PopulateChildren(int L, int blank) {
void PopulateChildren(int L, int blank, int t) {
if (HasChildren()) {
return;
}
Expand All @@ -74,6 +75,7 @@ struct BeamEntry {
auto& c = children[cl];
c.parent = this;
c.label = ci;
c.time_step = t;
++cl;
}
}
Expand Down Expand Up @@ -104,8 +106,24 @@ struct BeamEntry {
return labels;
}

std::vector<int> TimeStepSeq(bool merge_repeated) const {
std::vector<int> time_steps;
int prev_label = -1;
const BeamEntry *c = this;
while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
if (!merge_repeated || c->label != prev_label) {
time_steps.push_back(c->time_step);
}
prev_label = c->label;
c = c->parent;
}
std::reverse(time_steps.begin(), time_steps.end());
return time_steps;
}

BeamEntry<CTCBeamState>* parent;
int label;
int time_step;
std::vector<BeamEntry<CTCBeamState> > children;
BeamProbability oldp;
BeamProbability newp;
Expand Down
34 changes: 19 additions & 15 deletions pytorch_ctc/src/ctc_beam_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ limitations under the License.
#include <memory>

#include "Eigen/Core"
//#include "tensorflow/core/lib/core/errors.h"
#include "util/top_n.h"
//#include "tensorflow/core/platform/logging.h"
//#include "tensorflow/core/platform/macros.h"
//#include "tensorflow/core/platform/types.h"
#include "ctc_beam_entry.h"
#include "ctc_beam_scorer.h"
#include "ctc_decoder.h"
Expand Down Expand Up @@ -95,11 +91,12 @@ class CTCBeamSearchDecoder : public CTCDecoder {
Status Decode(const CTCDecoder::SequenceLength& seq_len,
std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output,
CTCDecoder::ScoreOutput* scores) override;
CTCDecoder::ScoreOutput* scores,
std::vector<CTCDecoder::Output>* alignment) override;

// Calculate the next step of the beam search and update the internal state.
template <typename Vector>
void Step(const Vector& log_input_t);
void Step(const Vector& log_input_t, int time_step);

// Retrieve the beam scorer instance used during decoding.
BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
Expand All @@ -119,7 +116,9 @@ class CTCBeamSearchDecoder : public CTCDecoder {

// Extract the top n paths at current time step
Status TopPaths(int n, std::vector<std::vector<int>>* paths,
std::vector<float>* log_probs, bool merge_repeated) const;
std::vector<float>* log_probs,
std::vector<std::vector<int>> *alignments,
bool merge_repeated) const;

private:
int beam_width_;
Expand Down Expand Up @@ -148,11 +147,14 @@ template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
const CTCDecoder::SequenceLength& seq_len,
std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
std::vector<CTCDecoder::Output>* output,
CTCDecoder::ScoreOutput* scores,
std::vector<CTCDecoder::Output>* alignment) {
int batch_size_ = input[0].rows();
// Storage for top paths.
std::vector<std::vector<int>> beams;
std::vector<float> beam_log_probabilities;
std::vector<std::vector<int>> beam_alignments;
int top_n = output->size();
if (std::any_of(output->begin(), output->end(),
[batch_size_](const CTCDecoder::Output& output) -> bool {
Expand All @@ -172,7 +174,7 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(

for (int t = 0; t < seq_len_b; ++t) {
// Pass log-probabilities for this example + time.
Step(input[t].row(b));
Step(input[t].row(b), t);
} // for (int t...

// O(n * log(n))
Expand All @@ -187,8 +189,8 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
}


Status status =
TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
Status status = TopPaths(top_n, &beams, &beam_log_probabilities, &beam_alignments,
merge_repeated_);

if (!status.ok()) {
return status;
Expand All @@ -208,6 +210,7 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
// Copy output to the correct beam + batch
(*output)[i][b].swap(beams[i]);
(*scores)(b, i) = -beam_log_probabilities[i];
(*alignment)[i][b].swap(beam_alignments[i]);
}
} // for (int b...
return Status::OK();
Expand All @@ -216,7 +219,7 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
template <typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const Vector& raw_input) {
const Vector& raw_input, int time_step) {
Eigen::ArrayXf input = raw_input;
// Remove the max for stability when performing log-prob calculations.
input -= input.maxCoeff();
Expand Down Expand Up @@ -299,7 +302,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
}

if (!b->HasChildren()) {
b->PopulateChildren(num_classes_, blank_index_);
b->PopulateChildren(num_classes_, blank_index_, time_step);
}

for (BeamEntry& c : *b->Children()) {
Expand Down Expand Up @@ -348,7 +351,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {

// This beam root, and all of its children, will be in memory until
// the next reset.
beam_root_.reset(new BeamEntry(nullptr, -1, num_classes_, blank_index_));
beam_root_.reset(new BeamEntry(nullptr, -1, num_classes_, blank_index_, -1));
beam_root_->newp.total = 0.0; // ln(1)
beam_root_->newp.blank = 0.0; // ln(1)

Expand All @@ -362,7 +365,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
bool merge_repeated) const {
std::vector<std::vector<int>> *alignments, bool merge_repeated) const {
if (paths == nullptr || log_probs == nullptr) {
return errors::FailedPrecondition(
"Internal paths are null"
Expand Down Expand Up @@ -392,6 +395,7 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
BeamEntry* e((*branches)[i]);
paths->push_back(e->LabelSeq(merge_repeated));
log_probs->push_back(e->newp.total);
alignments->push_back(e->TimeStepSeq(merge_repeated));
}
return Status::OK();
}
Expand Down
10 changes: 6 additions & 4 deletions pytorch_ctc/src/ctc_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ limitations under the License.

#include "Eigen/Core"
#include "util/status.h"
//#include "tensorflow/core/lib/core/errors.h"
//#include "tensorflow/core/lib/core/status.h"

namespace pytorch {
namespace ctc {
Expand Down Expand Up @@ -48,7 +46,8 @@ class CTCDecoder {
// - scores(b, i) - b = 0 to batch_size; i = 0 to output.size()
virtual Status Decode(const SequenceLength& seq_len,
std::vector<Input>& input,
std::vector<Output>* output, ScoreOutput* scores) = 0;
std::vector<Output>* output, ScoreOutput* scores,
std::vector<Output> *alignment) = 0;

int num_classes() { return num_classes_; }

Expand All @@ -68,7 +67,8 @@ class CTCGreedyDecoder : public CTCDecoder {
Status Decode(const CTCDecoder::SequenceLength& seq_len,
std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output,
CTCDecoder::ScoreOutput* scores) override {
CTCDecoder::ScoreOutput* scores,
std::vector<CTCDecoder::Output> *alignment) override {
int batch_size_ = input[0].cols();
if (output->empty() || (*output)[0].size() < batch_size_) {
return errors::InvalidArgument(
Expand All @@ -83,6 +83,7 @@ class CTCGreedyDecoder : public CTCDecoder {
int seq_len_b = seq_len[b];
// Only writing to beam 0
std::vector<int>& output_b = (*output)[0][b];
std::vector<int> &alignment_b = (*alignment)[0][b];

int prev_class_ix = -1;
(*scores)(b, 0) = 0;
Expand All @@ -93,6 +94,7 @@ class CTCGreedyDecoder : public CTCDecoder {
if (max_class_ix != blank_index_ &&
!(merge_repeated_ && max_class_ix == prev_class_ix)) {
output_b.push_back(max_class_ix);
alignment_b.push_back(t);
}
prev_class_ix = max_class_ix;
}
Expand Down
14 changes: 6 additions & 8 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def test_simple_decode(self):
decoder_merge = pytorch_ctc.CTCBeamDecoder(scorer, labels, blank_index=1, space_index=-1, top_paths=1, beam_width=1, merge_repeated=True)
decoder_nomerge = pytorch_ctc.CTCBeamDecoder(scorer, labels, blank_index=1, space_index=-1, top_paths=1, beam_width=1, merge_repeated=False)

result_merge, _, result_merge_len = decoder_merge.decode(aa, seq_len)
result_nomerge, _, result_nomerge_len = decoder_nomerge.decode(aa, seq_len)
result_merge, _, result_merge_len, merge_alignments = decoder_merge.decode(aa, seq_len)
result_nomerge, _, result_nomerge_len, nomerge_alignments = decoder_nomerge.decode(aa, seq_len)
self.assertEqual(result_merge_len[0][0], 1)
self.assertEqual(result_nomerge_len[0][0], 2)
self.assertEqual(result_merge.numpy()[0,0,:result_merge_len[0][0]].tolist(), [0])
Expand All @@ -31,8 +31,8 @@ def test_simple_decode_different_blank_idx(self):
decoder_merge = pytorch_ctc.CTCBeamDecoder(scorer, labels, blank_index=0, space_index=-1, top_paths=1, beam_width=1, merge_repeated=True)
decoder_nomerge = pytorch_ctc.CTCBeamDecoder(scorer, labels, blank_index=0, space_index=-1, top_paths=1, beam_width=1, merge_repeated=False)

result_merge, _, result_merge_len = decoder_merge.decode(aa, seq_len)
result_nomerge, _, result_nomerge_len = decoder_nomerge.decode(aa, seq_len)
result_merge, _, result_merge_len, merge_alignments = decoder_merge.decode(aa, seq_len)
result_nomerge, _, result_nomerge_len, nomerge_alignments = decoder_nomerge.decode(aa, seq_len)

self.assertEqual(result_merge_len[0][0], 1)
self.assertEqual(result_nomerge_len[0][0], 2)
Expand Down Expand Up @@ -73,9 +73,8 @@ def test_ctc_decoder_beam_search(self):
scorer = pytorch_ctc.Scorer()
decoder = pytorch_ctc.CTCBeamDecoder(scorer, labels, blank_index=5, space_index=-1, top_paths=2, beam_width=2, merge_repeated=False)

decode_result, scores, decode_len = decoder.decode(th_input, th_seq_len)
decode_result, scores, decode_len, alignments = decoder.decode(th_input, th_seq_len)

#self.assertEqual(result_merge.numpy()[0,0,:result_len[0][0]].tolist(), [1])
self.assertEqual(decode_len[0][0], 2)
self.assertEqual(decode_len[1][0], 3)
self.assertEqual(decode_result.numpy()[0,0,:decode_len[0][0]].tolist(), [1, 0])
Expand Down Expand Up @@ -116,9 +115,8 @@ def test_ctc_decoder_beam_search_different_blank_idx(self):
scorer = pytorch_ctc.Scorer()
decoder = pytorch_ctc.CTCBeamDecoder(scorer, labels, blank_index=0, space_index=-1, top_paths=2, beam_width=2, merge_repeated=False)

decode_result, scores, decode_len = decoder.decode(th_input, th_seq_len)
decode_result, scores, decode_len, alignments = decoder.decode(th_input, th_seq_len)

#self.assertEqual(result_merge.numpy()[0,0,:result_len[0][0]].tolist(), [1])
self.assertEqual(decode_len[0][0], 2)
self.assertEqual(decode_len[1][0], 3)
self.assertEqual(decode_result.numpy()[0,0,:decode_len[0][0]].tolist(), [2, 1])
Expand Down

0 comments on commit b45e950

Please sign in to comment.