Skip to content

Commit

Permalink
Merged PR 11566: removed overstuff and understuff features
Browse files Browse the repository at this point in the history
removed overstuff and understuff features
  • Loading branch information
frankseide authored and emjotde committed Feb 15, 2020
1 parent 1044f7f commit e09f713
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 54 deletions.
6 changes: 0 additions & 6 deletions src/common/config_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,12 +769,6 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) {
{"0"});
cli.add<bool>("--mini-batch-track-lr",
"Dynamically track mini-batch size inverse to actual learning rate (not considering lr-warmup)");
cli.add<size_t>("--mini-batch-overstuff",
"[experimental] Stuff this much more data into a minibatch, but scale down the LR and progress counter",
1);
cli.add<size_t>("--mini-batch-understuff",
"[experimental] Break each batch into this many updates",
1);
}
// clang-format on
}
Expand Down
52 changes: 6 additions & 46 deletions src/training/graph_group_sync.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ static double roundUpRatio(double ratio) {
// helper routine that handles accumulation and load-balancing of sub-batches to fill all devices
// It adds 'newBatch' to 'pendingBatches_', and if sufficient batches have been queued, then
// returns 'pendingBatches_' in 'subBatches' and resets it. If not, it returns false.
bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuff,
bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch,
std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches) {
// The reader delivers in chunks of these sizes, according to case:
// - no dynamic MB-size scaling:
Expand Down Expand Up @@ -199,9 +199,6 @@ bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuf
ratio *= (double)refBatchLabels / (double)(typicalTrgBatchWords_ * updateMultiplier_);
}

// overstuff: blow up ratio by a factor, which we later factor into the learning rate
ratio *= (double)overstuff;

// round up to full batches if within a certain error margin --@BUGBUG: Not invariant w.r.t. GPU size, as ratio is relative to what fits into 1 GPU
ratio = roundUpRatio(ratio);

Expand Down Expand Up @@ -267,41 +264,18 @@ bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuf
void SyncGraphGroup::update(Ptr<data::Batch> newBatch) /*override*/ {
validate();

size_t overstuff = options_->get<size_t>("mini-batch-overstuff");
if (overstuff != 1)
LOG_ONCE(info, "Overstuffing minibatches by a factor of {}", overstuff);
std::vector<Ptr<data::Batch>> subBatches;
size_t numReadBatches; // actual #batches delivered by reader, for restoring from checkpoint --@TODO: reader should checkpoint itself; should not go via the scheduler
bool gotSubBatches = tryGetSubBatches(newBatch, overstuff, subBatches, numReadBatches);
bool gotSubBatches = tryGetSubBatches(newBatch, subBatches, numReadBatches);

// not enough data yet: return right away
if (!gotSubBatches)
return;

// for testing the hypothesis that one can always go smaller. This is independent of overstuff.
size_t understuff = options_->get<size_t>("mini-batch-understuff");
if (understuff != 1)
LOG_ONCE(info, "Understuffing minibatches by a factor of {}", understuff);
if (understuff == 1)
update(subBatches, numReadBatches);
else {
std::vector<Ptr<data::Batch>> subBatches1;
for (auto& b : subBatches) {
auto bbs = b->split(understuff);
for (auto& bb : bbs)
subBatches1.push_back(bb);
}
for (size_t i = 0; i < understuff; i++) {
std::vector<Ptr<data::Batch>> subBatchRange(subBatches1.begin() + i * subBatches1.size() / understuff, subBatches1.begin() + (i+1) * subBatches1.size() / understuff);
if (!subBatchRange.empty())
update(subBatchRange, numReadBatches * (i+1) / understuff - numReadBatches * i / understuff);
}
}
update(subBatches, numReadBatches);
}

void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t numReadBatches) {
size_t overstuff = options_->get<size_t>("mini-batch-overstuff");
//size_t understuff = options_->get<size_t>("mini-batch-understuff");
// determine num words for dynamic hyper-parameter adjustment
// @TODO: We can return these directly from tryGetSubBatches()
size_t batchSize = 0;
Expand All @@ -310,9 +284,6 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
batchSize += batch->size();
batchTrgWords += batch->wordsTrg();
}
// effective batch size: batch should be weighted like this. This will weight down the learning rate.
size_t effectiveBatchTrgWords = (size_t)ceil(batchTrgWords / (double)overstuff);
size_t effectiveBatchSize = (size_t)ceil(batchSize / (double)overstuff);

// Helper to access the subBatches array
auto getSubBatch = [&](size_t warp, size_t localDeviceIndex, size_t rank) -> Ptr<data::Batch> {
Expand Down Expand Up @@ -353,32 +324,21 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
auto rationalLoss = builders_[localDeviceIndex]->build(graph, subBatch);
graph->forward();

StaticLoss tempLoss = *rationalLoss; // needed for overstuff
tempLoss.loss /= (float)overstuff; // @TODO: @fseide: scale only loss? should this scale labels too?

localDeviceLosses[localDeviceIndex] += tempLoss;
localDeviceLosses[localDeviceIndex] += *rationalLoss;
graph->backward(/*zero=*/false); // (gradients are reset before we get here)
}
});
// At this point, each device on each MPI process has a gradient aggregated over a subset of the sub-batches.

// only needed for overstuff now
float div = (float)overstuff; // (note: with Adam, a constant here makes no difference)

// Update parameter shard with gradient shard
auto update = [&](size_t idx, size_t begin, size_t end) {
auto curGrad = graphs_[idx]->params()->grads()->subtensor(begin, end-begin);
auto curParam = graphs_[idx]->params()->vals()->subtensor(begin, end-begin);

if(div != 1.f) {
using namespace functional;
Element(_1 = _1 / div, curGrad); // average if overstuffed
}

// actual model update
auto updateTrgWords =
/*if*/(options_->get<std::string>("cost-type") == "ce-sum") ?
effectiveBatchTrgWords // if overstuffing then bring the count back to the original value
batchTrgWords
/*else*/:
OptimizerBase::mbSizeNotProvided;
shardOpt_[idx]->update(curParam, curGrad, updateTrgWords);
Expand All @@ -405,7 +365,7 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num

if(scheduler_) {
// track and log localLoss
scheduler_->update(localLoss, numReadBatches, effectiveBatchSize, effectiveBatchTrgWords, mpi_);
scheduler_->update(localLoss, numReadBatches, batchSize, batchTrgWords, mpi_);

// save intermediate model (and optimizer state) to file
if(scheduler_->saving())
Expand Down
2 changes: 1 addition & 1 deletion src/training/graph_group_sync.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SyncGraphGroup : public GraphGroup, public ExponentialSmoothing {
void barrier() const { mpi_->barrier(); } // (we need this several times)
void swapParamsAvg() { if (mvAvg_ && paramsAvg_.size() > 0) comm_->swapParams(paramsAvg_); } // note: must call this on all MPI ranks in parallel

bool tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuff, std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches);
bool tryGetSubBatches(Ptr<data::Batch> newBatch, std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches);
void update(std::vector<Ptr<data::Batch>> subBatches, size_t numReadBatches);

public:
Expand Down
2 changes: 1 addition & 1 deletion src/translator/beam_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class BeamSearch {
if(factoredVocab) { // when using factoredVocab, extract the EOS lemma index from the word id; we are predicting factors one by one here, hence lemma only
std::vector<size_t> eosFactors;
factoredVocab->word2factors(factoredVocab->getEosId(), eosFactors);
wordIdx = eosFactors[0];
wordIdx = (WordIndex)eosFactors[0];
} else { // without factoredVocab lemma index and word index are the same. Safe cruising.
wordIdx = trgVocab_->getEosId().toWordIndex();
}
Expand Down

0 comments on commit e09f713

Please sign in to comment.