Merged PR 12442: cherry pick a few improvements/fixes from Frank's branch

Cherry pick a few improvements/fixes from Frank's branch
* Adds Frank's fix for label-based mini-batch sizing from his current experimental branch.
* Also copies minor improvements and a few comments.
emjotde authored and ugermann committed May 20, 2020
1 parent 34bc47c commit bc8b6fa
Showing 4 changed files with 25 additions and 7 deletions.
6 changes: 6 additions & 0 deletions src/layers/loss.h
@@ -93,6 +93,12 @@ struct StaticLoss {
StaticLoss(const RationalLoss& dynamic)
: loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}

StaticLoss operator +(const StaticLoss& other) const {
StaticLoss res(*this);
res += other;
return res;
}

StaticLoss& operator +=(const StaticLoss& other) {
loss = loss + other.loss;
count = count + other.count;
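The value-returning operator+ above complements the existing in-place operator+=: it copies the left operand, folds the right one in, and leaves both arguments untouched. That non-mutating form is what standard algorithms such as std::accumulate need, and the graph_group_sync.cpp change further down relies on exactly that. A small illustrative use, assuming the usual marian namespace and include path; the helper function is hypothetical:

#include "layers/loss.h"   // marian::StaticLoss as patched above

// Hypothetical helper: element-wise sum of two losses without modifying either input.
static marian::StaticLoss sumTwo(const marian::StaticLoss& a, const marian::StaticLoss& b) {
  marian::StaticLoss c = a + b;   // c.loss = a.loss + b.loss, c.count = a.count + b.count
  return c;                       // a and b are unchanged, unlike with operator+=
}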
9 changes: 9 additions & 0 deletions src/optimizers/optimizers.cpp
@@ -139,6 +139,15 @@ void Adam::updateImpl(Tensor params, Tensor grads, size_t actualMBSize, size_t refMBWords) {
double Tref = (double)refMBWords;

// adjust for minibatch-size changes if Adam parameters are given a reference size (else do nothing)
// Why the T/Tref factor on eta? The Adam optimizer adds an RMS-normalized gradient
// value (times learning rate) to the model. We know that for Tref, that learning rate is good.
// If we increase the batch size by (T/Tref), then without adjustment, we would still add an
// RMS-normalized gradient value. That means that the contribution of an individual label is
// now weighted down by (T/Tref). However, batch-size agnostic hyper-parameterization aims to keep
// the weight on the contribution of each label gradient invariant. Thus, we must undo that
// down-weighting, by multiplying the RMS-normalized gradient value by an additional factor
// of (T/Tref). This is implemented here by locally multiplying the learning rate
// with that factor.
double eta = eta_ * (T/Tref);
double beta1 = beta1_;
double beta2 = beta2_;
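In short: the Adam step is RMS-normalized, so raising the number of labels per update by T/Tref would silently down-weight each individual label's contribution by that same factor; multiplying eta by T/Tref restores it. A worked sketch of just that one line (all numbers are made up; the real eta_, T and Tref come from the optimizer configuration and the batch):

#include <cstdio>

int main() {
  double eta_ = 3e-4;      // base learning rate (made-up value)
  double T    = 100000.0;  // actual target labels contributing to this update
  double Tref = 25000.0;   // reference label count the hyper-parameters were tuned for
  double eta  = eta_ * (T / Tref);  // same adjustment as in Adam::updateImpl above
  std::printf("effective eta = %g\n", eta);  // 3e-4 * 4 = 1.2e-3
  return 0;
}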
10 changes: 4 additions & 6 deletions src/training/graph_group_sync.cpp
@@ -192,7 +192,7 @@ bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch,
// If a reference is given, then at progress == mbWarmup.n (ratio=1), we would like to have refBatchLabels instead of whichever
// the actual batch size is. Since we cannot know the future actual batch sizes that will be delivered
// by the reader, we approximate them with (typicalTrgBatchWords * updateMultiplier), and scale ratio accordingly.
- auto refBatchLabels = options_->get<size_t>("mini-batch-words-ref");
+ auto refBatchLabels = options_->get<size_t>("mini-batch-words");
if (refBatchLabels != 0) {
LOG_ONCE(info, "[scheduler] Scaling to {} reference labels, using actual-batch-word estimate of {}", refBatchLabels, typicalTrgBatchWords_);
ABORT_IF(typicalTrgBatchWords_ == 0, "Dynamic scaling with words target requires MB size to be known in words"); // happens if MB size is specified in sentences
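A hedged sketch of the rescaling the comment above describes: if ratio = 1 should correspond to refBatchLabels rather than to the estimated actual update size (typicalTrgBatchWords * updateMultiplier), then ratio is multiplied by the quotient of the two. This illustrates the comment rather than reproducing the partly elided code; all numbers are invented:

#include <cstdio>

int main() {
  double ratio                = 0.5;      // warm-up progress ratio (invented)
  double refBatchLabels       = 25000.0;  // reference label count (refBatchLabels above)
  double typicalTrgBatchWords = 10000.0;  // running estimate of actual labels per batch
  double updateMultiplier     = 2.0;      // batches accumulated per model update
  // Rescale so that ratio == 1 yields refBatchLabels instead of the estimated actual size.
  ratio *= refBatchLabels / (typicalTrgBatchWords * updateMultiplier);
  std::printf("adjusted ratio = %g\n", ratio);  // 0.5 * 25000/20000 = 0.625
  return 0;
}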
@@ -338,7 +338,7 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
// actual model update
auto updateTrgWords =
/*if*/(options_->get<std::string>("cost-type") == "ce-sum") ?
- batchTrgWords
+ batchTrgWords // total number of labels across all GPUs and nodes
/*else*/:
OptimizerBase::mbSizeNotProvided;
shardOpt_[idx]->update(curParam, curGrad, updateTrgWords);
@@ -350,10 +350,8 @@
};

// cost across all local devices (scheduler will aggregate cross-process)
- StaticLoss localLoss;
- for(auto& l : localDeviceLosses) // localDeviceLosses is already summed up over delay steps
- localLoss += l;
+ StaticLoss localLoss = std::accumulate(localDeviceLosses.begin(), localDeviceLosses.end(), StaticLoss());

// model update
if (std::isfinite(localLoss.loss) || mpi_->numMPIProcesses() > 1) { // guard against NaN (except with MPI, as this simple way could hang it)
comm_->scatterReduceAndResetGrads(); // reduce gradients across all devices and MPI nodes into shards
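The per-device summation loop is folded into a single std::accumulate call; that compiles precisely because of the operator+ added to StaticLoss in this commit, since std::accumulate repeatedly evaluates init = init + element. A reduced, self-contained sketch of the pattern (the two-field struct stands in for StaticLoss; the values are made up):

#include <cassert>
#include <numeric>
#include <vector>

struct MiniLoss {   // stand-in for StaticLoss
  float loss, count;
  MiniLoss& operator+=(const MiniLoss& other) { loss += other.loss; count += other.count; return *this; }
  MiniLoss operator+(const MiniLoss& other) const { MiniLoss res(*this); res += other; return res; }
};

int main() {
  std::vector<MiniLoss> localDeviceLosses = {{1.f, 10.f}, {2.f, 20.f}, {3.f, 30.f}};
  MiniLoss init{0.f, 0.f};
  MiniLoss localLoss = std::accumulate(localDeviceLosses.begin(), localDeviceLosses.end(), init);
  assert(localLoss.loss == 6.f && localLoss.count == 60.f);  // element-wise totals
  return 0;
}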
7 changes: 6 additions & 1 deletion src/training/validator.h
@@ -66,8 +66,13 @@ class Validator : public ValidatorBase {
options_->set("max-length", options_->get<size_t>("valid-max-length"));
options_->set("max-length-crop", true); // @TODO: make this configureable
}
- if(options_->has("valid-mini-batch"))

+ // @TODO: make this work with mini-batch-fit etc.
+ if(options_->has("valid-mini-batch")) {
options_->set("mini-batch", options_->get<size_t>("valid-mini-batch"));
+ options_->set("mini-batch-words", 0);
+ }

options_->set("mini-batch-sort", "src");
options_->set("maxi-batch", 10);
}
Expand Down
