-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merged PR 11929: Move around code to make later comparison with FP16 code easier
This does not introduce any new functionality; it just moves code around so that future PRs are easier to compare. The old GraphGroup code is moved to training/deprecated; once it is clear there is nothing in there that is worth saving, it will be deleted. Also replaces -Ofast with -O3 and makes sure -ffinite-math-only is turned off.
- Loading branch information
Showing
17 changed files
with
107 additions
and
200 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
#include "training/graph_group.h" | ||
|
||
namespace marian { | ||
|
||
// Constructs the graph group from the given options. The options pointer is
// stored in options_, and the optimizer opt_ is created from the same options
// via Optimizer(options).
GraphGroup::GraphGroup(Ptr<Options> options)
    : options_(options),
      opt_(Optimizer(options)) {}
|
||
// Sanity check: aborts with an error message if the finalized_ flag
// (set by finalize()) is already true.
void GraphGroup::validate() {
  ABORT_IF(finalized_, "Training has already finished.");
}
|
||
// Marks this graph group as finished by setting finalized_; a subsequent
// call to validate() will abort.
void GraphGroup::finalize() {
  finalized_ = true;
}
|
||
/**
 * Determines, for a range of sentence lengths, how large a batch can be while
 * the built graph still fits (as reported by graph->fits()).
 *
 * Fake batches are created with data::CorpusBatch::fakeBatch and built with
 * model->build(graph, batch). Every batch that fits is recorded in the
 * returned statistics via stats->add(batch, multiplier).
 *
 * Options read: "train-sets", "mini-batch-fit-step", "max-length" and
 * "input-types".
 *
 * @param graph      graph on which the fake batches are built and checked
 * @param model      criterion function whose build() is invoked per fake batch
 * @param vocabs     vocabularies forwarded to CorpusBatch::fakeBatch
 * @param multiplier value forwarded to BatchStats::add for each fitting batch
 * @return the collected batch statistics
 */
Ptr<data::BatchStats> GraphGroup::collectStats(Ptr<ExpressionGraph> graph,
                                               Ptr<models::ICriterionFunction> model,
                                               const std::vector<Ptr<Vocab>>& vocabs,
                                               double multiplier) {
  auto stats = New<data::BatchStats>();

  size_t numFiles = options_->get<std::vector<std::string>>("train-sets").size();

  // Initialize first batch to step size
  size_t first = options_->get<size_t>("mini-batch-fit-step");

  // Increase batch size and sentence length by this step size
  size_t step = options_->get<size_t>("mini-batch-fit-step");

  // A zero step would divide by zero below and never advance the length loop.
  ABORT_IF(step == 0, "mini-batch-fit-step must be greater than 0");

  // Round the maximum length up to the next multiple of the step size
  size_t maxLength = options_->get<size_t>("max-length");
  maxLength = static_cast<size_t>(std::ceil(maxLength / static_cast<float>(step)) * step);

  // Per-stream upper bound on the fake sentence length
  std::vector<size_t> localMaxes(numFiles, maxLength);
  auto inputTypes = options_->get<std::vector<std::string>>("input-types", {});
  // A "class" input should be only one class label per line, hence restricting
  // its length to 1. The index is also bounded by localMaxes.size() so that
  // surplus "input-types" entries cannot write out of range.
  for(size_t i = 0; i < inputTypes.size() && i < localMaxes.size(); ++i)
    if(inputTypes[i] == "class")
      localMaxes[i] = 1;

  // Find an upper bound for the batch size: double it for as long as a batch
  // of the first (shortest) length still fits.
  size_t maxBatch = 512;
  bool fits = true;
  while(fits) {
    std::vector<size_t> lengths(numFiles, first);
    for(size_t j = 0; j < lengths.size(); ++j)  // apply length restrictions
      lengths[j] = std::min(lengths[j], localMaxes[j]);

    auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_);
    auto cost = model->build(graph, batch);
    fits = graph->fits();
    if(fits)
      maxBatch *= 2;
  }

  // Do a binary search for maximum batch size that fits into given workspace memory
  // for a tested sentence length.
  for(size_t i = step; i <= maxLength; i += step) {
    size_t start = 1;
    size_t end = maxBatch;

    std::vector<size_t> lengths(numFiles, i);
    for(size_t j = 0; j < lengths.size(); ++j)  // apply length restrictions
      lengths[j] = std::min(lengths[j], localMaxes[j]);
    fits = true;

    do {
      size_t current = (start + end) / 2;
      auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_);
      auto cost = model->build(graph, batch);
      fits = graph->fits();

      LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits);

      if(fits) {
        stats->add(batch, multiplier);
        start = current + 1;
      } else {
        end = current - 1;
      }
      // `end` can drop below `start` (e.g. end == 0 after a batch of size 1 did
      // not fit). Check the ordering first so that the unsigned subtraction
      // cannot wrap around and keep the loop running.
    } while(end > start && end - start > step);

    // The next (longer) length starts its search from this bound
    maxBatch = start;
  }
  return stats;
}
|
||
// Stores the given value in typicalTrgBatchWords_.
// Needed for dynamic MB scaling.
void GraphGroup::setTypicalTrgBatchWords(size_t typicalTrgBatchWords) {
  typicalTrgBatchWords_ = typicalTrgBatchWords;
}
|
||
} |
Oops, something went wrong.