From cec454fd90198869a32d67b99f3542f26ec84960 Mon Sep 17 00:00:00 2001 From: zyt1024 <42999008+zyt1024@users.noreply.github.com> Date: Thu, 25 Jan 2024 14:47:35 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90fix=20bug=E3=80=91fix=20bug=20for=20pr?= =?UTF-8?q?ogram=5Fconverter=20=20(#61051)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix convert bug * modify code style --- paddle/fluid/framework/program_converter.cc | 62 ++++++++++----------- 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/framework/program_converter.cc b/paddle/fluid/framework/program_converter.cc index dd4a78a6672099..48d45277dfffdf 100644 --- a/paddle/fluid/framework/program_converter.cc +++ b/paddle/fluid/framework/program_converter.cc @@ -43,42 +43,36 @@ std::pair> DetectLegacyOps( // get *all kinds* of formats of op versions and op version map to a unified // representation before comparison can be done in a neat way - if (!program->HasOpVersionMap()) { - is_legacy_program = true; - } else { - legacy_op_versions = - paddle::framework::compatible::pb::GetLegacyOpVersions(); + legacy_op_versions = paddle::framework::compatible::pb::GetLegacyOpVersions(); - const auto* _op_version_map = program->OpVersionMap(); - for (int i = 0; i < _op_version_map->pair_size(); ++i) { - auto pair = - std::make_pair(_op_version_map->pair(i).op_name(), - static_cast( - _op_version_map->pair(i).op_version().version())); - program_op_versions.insert(pair); - } + const auto* _op_version_map = program->OpVersionMap(); + for (int i = 0; i < _op_version_map->pair_size(); ++i) { + auto pair = std::make_pair( + _op_version_map->pair(i).op_name(), + static_cast(_op_version_map->pair(i).op_version().version())); + program_op_versions.insert(pair); + } - const size_t num_blocks = program->Size(); - for (size_t i = 0; i < num_blocks; i++) { - BlockDesc* block = program->MutableBlock(i); - const size_t num_ops = block->OpSize(); - for (size_t j = 0; j < num_ops; j++) { - OpDesc* op = block->Op(static_cast(j)); - const std::string& op_type = op->Type(); - if (needConvertedOperators.find(op_type) != - needConvertedOperators.end()) { - // If an operator (program_op) is in the needConvertedOperators set, - // it indicates that the operator may need to be converted. - // Further judgement: if the operator does not exist in the - // program_op_version_map, the operator needs to be converted. - // Moreover, if the operator does exist and its program_op_version_ - // is less than or equal legacy_op_version, the operator also needs to - // be converted. - if (!program_op_versions.count(op_type) || - program_op_versions[op_type] <= legacy_op_versions[op_type]) { - is_legacy_program = true; - legacy_op_map.insert(std::make_pair(op_type, op)); - } + const size_t num_blocks = program->Size(); + for (size_t i = 0; i < num_blocks; i++) { + BlockDesc* block = program->MutableBlock(i); + const size_t num_ops = block->OpSize(); + for (size_t j = 0; j < num_ops; j++) { + OpDesc* op = block->Op(static_cast(j)); + const std::string& op_type = op->Type(); + if (needConvertedOperators.find(op_type) != + needConvertedOperators.end()) { + // If an operator (program_op) is in the needConvertedOperators set, + // it indicates that the operator may need to be converted. + // Further judgement: if the operator does not exist in the + // program_op_version_map, the operator needs to be converted. + // Moreover, if the operator does exist and its program_op_version_ + // is less than or equal legacy_op_version, the operator also needs to + // be converted. + if (!program_op_versions.count(op_type) || + program_op_versions[op_type] <= legacy_op_versions[op_type]) { + is_legacy_program = true; + legacy_op_map.insert(std::make_pair(op_type, op)); } } }