Merged PR 20933: beam & batch works for non-factored models
Hieu Hoang authored and emjotde committed Oct 13, 2021
1 parent 03fe175 commit 2d79ad0
Showing 3 changed files with 21 additions and 8 deletions.
22 changes: 16 additions & 6 deletions src/layers/output.cpp
@@ -313,14 +313,24 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
     }
     return Logits(std::move(allLogits), factoredVocab_);
   } else if(shortlist_) {
-    return Logits(affineOrDot(input,
-                              shortlist_->getCachedShortWt(),
-                              shortlist_->getCachedShortb(),
+    const Shape &inputShape = input->shape();
+    assert(inputShape[1] == 1); // time dimension always 1 for decoding
+    input = reshape(input, {inputShape[0], inputShape[2], 1, inputShape[3]});
+
+    Expr Wt = shortlist_->getCachedShortWt();
+    Expr b = shortlist_->getCachedShortb();
+    Expr ret = affineShortlist(input,
+                               Wt,
+                               b,
                               false,
-                              /*transB=*/isLegacyUntransposedW ? false : true));
+                               /*transB=*/isLegacyUntransposedW ? false : true);
+    const Shape &retShape = ret->shape();
+    assert(retShape[2] == 1); // time dimension always 1 for decoding
+    ret = reshape(ret, {retShape[0], 1, retShape[1], retShape[3]});
+    return Logits(ret);
   } else {
-    return Logits(
-        affineOrDot(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
+    Expr ret = affineOrDot(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true);
+    return Logits(ret);
   }
 }
 
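Note on the reshape pair around affineShortlist: only a size-1 time axis is moved, and in a row-major tensor, swapping a size-1 dimension with its neighbour changes no linear offsets, so a metadata-only reshape (no transpose, no copy) suffices. A minimal standalone sketch of that invariant (illustrative C++, not Marian code):

```cpp
#include <cassert>
#include <cstddef>

// Row-major linear offset of element (i0, i1, i2, i3) in a tensor of shape {d0, d1, d2, d3}.
std::size_t offset(std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3,
                   std::size_t d1, std::size_t d2, std::size_t d3) {
  return ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
}

int main() {
  const std::size_t a = 2, b = 3, c = 4; // stand-ins for the beam, batch, and feature dims
  for (std::size_t i = 0; i < a; ++i)
    for (std::size_t j = 0; j < b; ++j)
      for (std::size_t k = 0; k < c; ++k)
        // element (i, 0, j, k) of shape {a, 1, b, c} occupies the same offset
        // as element (i, j, 0, k) of shape {a, b, 1, c}
        assert(offset(i, 0, j, k, /*d1=*/1, /*d2=*/b, /*d3=*/c)
            == offset(i, j, 0, k, /*d1=*/b, /*d2=*/1, /*d3=*/c));
  return 0;
}
```

The same argument, applied in reverse, covers reshaping ret from {d0, d1, 1, d3} back to {d0, 1, d1, d3} before wrapping it in Logits.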
5 changes: 3 additions & 2 deletions src/translator/beam_search.cpp
@@ -94,7 +94,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
         // For factored decoding, the word is built over multiple decoding steps,
         // starting with the lemma, then adding factors one by one.
         if (factorGroup == 0) {
-          word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
+          word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx);
           std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
           //LOG(info, "{} + {} ({}) -> {} -> {}",
           //    factoredVocab->decode(prevHyp->tracebackWords()),
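For context, word2factors above splits a packed word back into one index per factor group. A hypothetical sketch of such a packing (mixed-radix encoding; the group sizes and layout are assumptions for illustration, not FactoredVocab's actual representation):

```cpp
#include <cstddef>
#include <vector>

// Pack one index per factor group into a single virtual word id (mixed radix).
std::size_t factors2word(const std::vector<std::size_t>& factorIndices,
                         const std::vector<std::size_t>& groupSizes) {
  std::size_t word = 0;
  for (std::size_t g = 0; g < groupSizes.size(); ++g)
    word = word * groupSizes[g] + factorIndices[g]; // requires factorIndices[g] < groupSizes[g]
  return word;
}

// Inverse: peel the factor indices back off, last group first.
std::vector<std::size_t> word2factors(std::size_t word,
                                      const std::vector<std::size_t>& groupSizes) {
  std::vector<std::size_t> factorIndices(groupSizes.size());
  for (std::size_t g = groupSizes.size(); g-- > 0;) {
    factorIndices[g] = word % groupSizes[g];
    word /= groupSizes[g];
  }
  return factorIndices;
}
```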
@@ -115,7 +115,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
         }
       }
       else if (shortlist)
-        word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) origBatchIdx, wordIdx));
+        word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx));
       else
         word = Word::fromWordIndex(wordIdx);
 
@@ -330,6 +330,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
     auto prevBatchIdxMap = batchIdxMap; // [origBatchIdx -> currentBatchIdx] but shifted by one time step
     // main loop over output time steps
     for (size_t t = 0; ; t++) {
+      //std::cerr << "\nstep=" << t << std::endl;
       ABORT_IF(origDimBatch != beams.size(), "Lost a batch entry??");
       // determine beam size for next output time step, as max over still-active sentences
       // E.g. if all batch entries are down from beam 5 to no more than 4 surviving hyps, then
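The origBatchIdx → currentBatchIdx change matters once finished sentences have been purged from the batch, presumably because the per-batch shortlist is indexed by position in the current, shrunken batch rather than by original position. A hedged sketch of how such an [origBatchIdx -> currentBatchIdx] map could be rebuilt (illustrative names, not Marian's implementation):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical: rebuild an [origBatchIdx -> currentBatchIdx] map after
// finished sentences have been purged from the batch.
std::vector<std::size_t> rebuildBatchIdxMap(const std::vector<bool>& finished) {
  std::vector<std::size_t> batchIdxMap(finished.size());
  std::size_t currentBatchIdx = 0;
  for (std::size_t origBatchIdx = 0; origBatchIdx < finished.size(); ++origBatchIdx) {
    batchIdxMap[origBatchIdx] = currentBatchIdx; // for finished entries this value is never read
    if (!finished[origBatchIdx])
      ++currentBatchIdx; // only surviving sentences occupy a slot in the purged batch
  }
  return batchIdxMap;
}
```

Under that assumption, looking a hypothesis up in the shortlist with origBatchIdx would read another sentence's cached entries as soon as any earlier batch entry had finished, which is consistent with the one-token fix in the hunk above.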
2 changes: 2 additions & 0 deletions src/translator/nth_element.cpp
@@ -3,7 +3,9 @@
  * SPDX-License-Identifier: MIT
  */
 
+#include "common/utils.h"
 #include "translator/nth_element.h"
+
 #include <algorithm>
 #include <iterator>
 #include <limits>
