Skip to content

Commit

Permalink
actually save the merge file
Browse files Browse the repository at this point in the history
  • Loading branch information
emjotde authored and ugermann committed May 20, 2020
1 parent c95676e commit 71cc43a
Showing 1 changed file with 1 addition and 82 deletions.
83 changes: 1 addition & 82 deletions src/training/graph_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,88 +55,7 @@ class GraphGroup {
/**
 * Probes, for sentence lengths from the fit step up to the (rounded-up) maximum length,
 * which batch sizes still make graph->fits() return true after building the model on a
 * fake batch, and records every fitting fake batch in a BatchStats object.
 *
 * @param graph       graph on which the model is built for every probe
 * @param model       criterion function whose build() is invoked on each fake batch
 * @param vocabs      vocabularies forwarded to data::CorpusBatch::fakeBatch
 * @param multiplier  forwarded unchanged to BatchStats::add for every fitting batch
 * @return            the statistics collected from all probes that fit
 */
Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> graph,
                                   Ptr<models::ICriterionFunction> model,
                                   const std::vector<Ptr<Vocab>>& vocabs,
                                   double multiplier = 1.) {
  auto stats = New<data::BatchStats>();

  size_t numFiles = options_->get<bool>("tsv", false)
                        ? options_->get<size_t>("tsv-fields")
                        : options_->get<std::vector<std::string>>("train-sets").size();

  // Batch size and sentence length are both increased by this step size;
  // the very first probe also uses it as the sentence length.
  // NOTE(review): a step of 0 would divide by zero below and never advance the
  // length loop - presumably rejected during option validation; confirm.
  const size_t step  = options_->get<size_t>("mini-batch-fit-step");
  const size_t first = step;

  // Round the maximum length up to the next multiple of the step size
  size_t maxLength = options_->get<size_t>("max-length");
  maxLength = static_cast<size_t>(std::ceil(maxLength / static_cast<float>(step)) * step);

  // this should be only one class label per line on input, hence restricting length to 1
  std::vector<size_t> localMaxes(numFiles, maxLength);
  auto inputTypes = options_->get<std::vector<std::string>>("input-types", {});
  // Never index past localMaxes, even if more input types than files are configured
  const size_t numTyped = std::min(inputTypes.size(), localMaxes.size());
  for(size_t i = 0; i < numTyped; ++i)
    if(inputTypes[i] == "class")
      localMaxes[i] = 1;

  // Builds the per-stream lengths for one probe with the length restrictions applied
  auto restrictedLengths = [&](size_t length) {
    std::vector<size_t> lengths(numFiles, length);
    for(size_t j = 0; j < lengths.size(); ++j)
      lengths[j] = std::min(lengths[j], localMaxes[j]);
    return lengths;
  };

  // Keep doubling the batch size (starting at 512) until a batch of the first
  // sentence length no longer fits; that size is the initial upper search bound.
  size_t maxBatch = 512;
  bool fits = true;
  const std::vector<size_t> firstLengths = restrictedLengths(first);
  while(fits) {
    auto batch = data::CorpusBatch::fakeBatch(firstLengths, vocabs, maxBatch, options_);
    auto cost = model->build(graph, batch);
    fits = graph->fits();
    if(fits)
      maxBatch *= 2;
  }

  // Do a binary search for maximum batch size that fits into given workspace memory
  // for a tested sentence length.
  for(size_t i = step; i <= maxLength; i += step) {
    size_t start = 1;
    size_t end = maxBatch;

    const std::vector<size_t> lengths = restrictedLengths(i);
    fits = true;

    do {
      size_t current = (start + end) / 2;
      auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_);
      auto cost = model->build(graph, batch);
      fits = graph->fits();

      LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits);

      if(fits) {
        stats->add(batch, multiplier);
        start = current + 1;
      } else {
        end = current - 1;
      }
      // `end` can drop below `start` after a failed probe; check the ordering first
      // so the unsigned difference cannot wrap around and keep the loop running.
    } while(end > start && end - start > step);

    // The next (longer) sentence length searches below the bound reached here
    maxBatch = start;
  }
  return stats;
}

// Setter for typicalTrgBatchWords_; the stored value is needed for dynamic MB scaling.
void setTypicalTrgBatchWords(size_t typicalTrgBatchWords)
{
  typicalTrgBatchWords_ = typicalTrgBatchWords;
}
};

/**
* Base class for multi-node versions of GraphGroups.
*/
class MultiNodeGraphGroupBase : public GraphGroup {
using Base = GraphGroup;

protected:
Ptr<IMPIWrapper> mpi_; // all MPI-like communication goes through this
double multiplier = 1.);

void setTypicalTrgBatchWords(size_t typicalTrgBatchWords);
};
Expand Down

0 comments on commit 71cc43a

Please sign in to comment.