diff --git a/paddle/fluid/framework/program_converter.cc b/paddle/fluid/framework/program_converter.cc index 82739e788bba36..dd4a78a6672099 100644 --- a/paddle/fluid/framework/program_converter.cc +++ b/paddle/fluid/framework/program_converter.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" @@ -30,23 +31,23 @@ namespace framework { using paddle::experimental::ExtractPlainVector; using paddle::experimental::WrapAsScalars; -std::pair> DetectLegacyOps( +static std::unordered_set needConvertedOperators = { + "assign_value", "set_value", "set_value_grad"}; + +std::pair> DetectLegacyOps( ProgramDesc* program) { bool is_legacy_program = false; - std::unordered_map legacy_op_versions; - std::unordered_map current_op_versions; + std::unordered_multimap legacy_op_map; std::unordered_map program_op_versions; + std::unordered_map legacy_op_versions; // get *all kinds* of formats of op versions and op version map to a unified // representation before comparison can be done in a neat way if (!program->HasOpVersionMap()) { is_legacy_program = true; } else { - for (const auto& pair : - paddle::framework::compatible::get_op_version_map()) { - current_op_versions.insert( - std::make_pair(pair.first, pair.second.version_id())); - } + legacy_op_versions = + paddle::framework::compatible::pb::GetLegacyOpVersions(); const auto* _op_version_map = program->OpVersionMap(); for (int i = 0; i < _op_version_map->pair_size(); ++i) { @@ -57,23 +58,32 @@ std::pair> DetectLegacyOps( program_op_versions.insert(pair); } - for (const auto& pair : program_op_versions) { - uint32_t program_op_version = pair.second; - if (!current_op_versions.count(pair.first)) { - // this means program_op_versions is more upated than - // current_op_versions it is loading a program from future versions of - // paddle - continue; - } - uint32_t current_op_version = current_op_versions.at(pair.first); - if (program_op_version < current_op_version) { - is_legacy_program = true; - legacy_op_versions.insert( - std::make_pair(pair.first, program_op_version)); + const size_t num_blocks = program->Size(); + for (size_t i = 0; i < num_blocks; i++) { + BlockDesc* block = program->MutableBlock(i); + const size_t num_ops = block->OpSize(); + for (size_t j = 0; j < num_ops; j++) { + OpDesc* op = block->Op(static_cast(j)); + const std::string& op_type = op->Type(); + if (needConvertedOperators.find(op_type) != + needConvertedOperators.end()) { + // If an operator (program_op) is in the needConvertedOperators set, + // it indicates that the operator may need to be converted. + // Further judgement: if the operator does not exist in the + // program_op_version_map, the operator needs to be converted. + // Moreover, if the operator does exist and its program_op_version_ + // is less than or equal legacy_op_version, the operator also needs to + // be converted. + if (!program_op_versions.count(op_type) || + program_op_versions[op_type] <= legacy_op_versions[op_type]) { + is_legacy_program = true; + legacy_op_map.insert(std::make_pair(op_type, op)); + } + } } } } - return std::make_pair(is_legacy_program, legacy_op_versions); + return std::make_pair(is_legacy_program, legacy_op_map); } namespace no_scalar { @@ -288,7 +298,7 @@ void ConvertProgram(ProgramDesc* program) { auto legacy_op_results = DetectLegacyOps(program); bool is_legacy_program = legacy_op_results.first; - const std::unordered_map& legacy_op_versions = + const std::unordered_multimap& legacy_ops = legacy_op_results.second; VLOG(3) << "is_legacy_program : " << is_legacy_program; @@ -303,25 +313,15 @@ void ConvertProgram(ProgramDesc* program) { VLOG(3) << "Converting program from old(no scalar attributes) to new(with " "scalar attributes)"; - const size_t num_blocks = program->Size(); - for (size_t i = 0; i < num_blocks; i++) { - BlockDesc* block = program->MutableBlock(i); - const size_t num_ops = block->OpSize(); - for (size_t j = 0; j < num_ops; j++) { - OpDesc* op = block->Op(static_cast(j)); - const std::string op_type = op->Type(); - if (op_type == "assign_value") { - VLOG(3) << "Converting program from old to new, op_type=" << op_type; - ConvertAssignValueOp(op); - } - if (!legacy_op_versions.count(op_type)) { - continue; - } - VLOG(3) << "Converting program from old to new, op_type=" << op_type; - if (op_type == "set_value" || op_type == "set_value_grad") { - ConvertSetValueOp(op); - } + for (const auto& pair : legacy_ops) { + const std::string op_type = pair.first; + VLOG(3) << "Converting program from old to new, op_type=" << op_type; + if (op_type == "set_value" || op_type == "set_value_grad") { + ConvertSetValueOp(pair.second); + } + if (op_type == "assign_value") { + ConvertAssignValueOp(pair.second); } } }