Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix bug]fix bug for program_converter #60629

Merged
merged 5 commits into from
Jan 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 41 additions & 41 deletions paddle/fluid/framework/program_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
Expand All @@ -30,23 +31,23 @@ namespace framework {
using paddle::experimental::ExtractPlainVector;
using paddle::experimental::WrapAsScalars;

std::pair<bool, std::unordered_map<std::string, uint32_t>> DetectLegacyOps(
static std::unordered_set<std::string> needConvertedOperators = {
"assign_value", "set_value", "set_value_grad"};

std::pair<bool, std::unordered_multimap<std::string, OpDesc*>> DetectLegacyOps(
ProgramDesc* program) {
bool is_legacy_program = false;
std::unordered_map<std::string, uint32_t> legacy_op_versions;
std::unordered_map<std::string, uint32_t> current_op_versions;
std::unordered_multimap<std::string, OpDesc*> legacy_op_map;
std::unordered_map<std::string, uint32_t> program_op_versions;
std::unordered_map<std::string, uint32_t> legacy_op_versions;

// get *all kinds* of formats of op versions and op version map to a unified
// representation before comparison can be done in a neat way
if (!program->HasOpVersionMap()) {
is_legacy_program = true;
} else {
for (const auto& pair :
paddle::framework::compatible::get_op_version_map()) {
current_op_versions.insert(
std::make_pair(pair.first, pair.second.version_id()));
}
legacy_op_versions =
paddle::framework::compatible::pb::GetLegacyOpVersions();

const auto* _op_version_map = program->OpVersionMap();
for (int i = 0; i < _op_version_map->pair_size(); ++i) {
Expand All @@ -57,23 +58,32 @@ std::pair<bool, std::unordered_map<std::string, uint32_t>> DetectLegacyOps(
program_op_versions.insert(pair);
}

for (const auto& pair : program_op_versions) {
uint32_t program_op_version = pair.second;
if (!current_op_versions.count(pair.first)) {
// this means program_op_versions is more upated than
// current_op_versions it is loading a program from future versions of
// paddle
continue;
}
uint32_t current_op_version = current_op_versions.at(pair.first);
if (program_op_version < current_op_version) {
is_legacy_program = true;
legacy_op_versions.insert(
std::make_pair(pair.first, program_op_version));
const size_t num_blocks = program->Size();
for (size_t i = 0; i < num_blocks; i++) {
BlockDesc* block = program->MutableBlock(i);
const size_t num_ops = block->OpSize();
for (size_t j = 0; j < num_ops; j++) {
OpDesc* op = block->Op(static_cast<int>(j));
const std::string& op_type = op->Type();
if (needConvertedOperators.find(op_type) !=
needConvertedOperators.end()) {
// If an operator (program_op) is in the needConvertedOperators set,
// it indicates that the operator may need to be converted.
// Further judgement: if the operator does not exist in the
// program_op_version_map, the operator needs to be converted.
// Moreover, if the operator does exist and its program_op_version_
// is less than or equal legacy_op_version, the operator also needs to
// be converted.
if (!program_op_versions.count(op_type) ||
program_op_versions[op_type] <= legacy_op_versions[op_type]) {
is_legacy_program = true;
legacy_op_map.insert(std::make_pair(op_type, op));
}
}
}
}
}
return std::make_pair(is_legacy_program, legacy_op_versions);
return std::make_pair(is_legacy_program, legacy_op_map);
}

namespace no_scalar {
Expand Down Expand Up @@ -288,7 +298,7 @@ void ConvertProgram(ProgramDesc* program) {

auto legacy_op_results = DetectLegacyOps(program);
bool is_legacy_program = legacy_op_results.first;
const std::unordered_map<std::string, uint32_t>& legacy_op_versions =
const std::unordered_multimap<std::string, OpDesc*>& legacy_ops =
legacy_op_results.second;

VLOG(3) << "is_legacy_program : " << is_legacy_program;
Expand All @@ -303,25 +313,15 @@ void ConvertProgram(ProgramDesc* program) {

VLOG(3) << "Converting program from old(no scalar attributes) to new(with "
"scalar attributes)";
const size_t num_blocks = program->Size();
for (size_t i = 0; i < num_blocks; i++) {
BlockDesc* block = program->MutableBlock(i);
const size_t num_ops = block->OpSize();
for (size_t j = 0; j < num_ops; j++) {
OpDesc* op = block->Op(static_cast<int>(j));
const std::string op_type = op->Type();

if (op_type == "assign_value") {
VLOG(3) << "Converting program from old to new, op_type=" << op_type;
ConvertAssignValueOp(op);
}
if (!legacy_op_versions.count(op_type)) {
continue;
}
VLOG(3) << "Converting program from old to new, op_type=" << op_type;
if (op_type == "set_value" || op_type == "set_value_grad") {
ConvertSetValueOp(op);
}
for (const auto& pair : legacy_ops) {
const std::string op_type = pair.first;
VLOG(3) << "Converting program from old to new, op_type=" << op_type;
if (op_type == "set_value" || op_type == "set_value_grad") {
ConvertSetValueOp(pair.second);
}
if (op_type == "assign_value") {
ConvertAssignValueOp(pair.second);
}
}
}
Expand Down