Skip to content

Commit

Permalink
rdisable inplace_pass
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome committed Dec 12, 2023
1 parent a867f69 commit e3e37ff
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
13 changes: 5 additions & 8 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,13 @@
#include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h"
#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h"
#include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h"
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/params_sync_among_devices_pass.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h"
#include "paddle/phi/core/flags.h"
#include "paddle/pir/pass/pass_manager.h"

PHI_DECLARE_bool(enable_pir_in_executor);
PHI_DECLARE_bool(pir_apply_inplace_pass);

namespace paddle {
namespace {
Expand Down Expand Up @@ -810,12 +808,6 @@ bool AnalysisPredictor::PrepareExecutor() {
pir_program_ = std::move(
paddle::dialect::PdOpLowerToKernelPass(pir_program_.get(), place_));

::pir::PassManager pm_for_kernel_program(::pir::IrContext::Instance(), 3);
if (FLAGS_pir_apply_inplace_pass) {
pm_for_kernel_program.AddPass(::pir::CreateInplacePass());
}
pm_for_kernel_program.Run(pir_program_.get());

executor_->PrepareInterpreterCore(
sub_scope_, *pir_program_, execution_config);
} else {
Expand Down Expand Up @@ -1734,8 +1726,13 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("auto_mixed_precision_pass");
pass_builder->AppendPass("inplace_op_var_pass");
LOG(INFO) << "This model run in GPU mixed precision mode with no ir "
"optimization.";
} else if (config_.new_executor_enabled()) {
argument_->SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("inplace_op_var_pass");
} else {
LOG(INFO)
<< "Ir optimization is turned off, no ir pass will be executed.";
Expand Down
25 changes: 25 additions & 0 deletions paddle/fluid/pir/drr/attr_type_uilts.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<int32_t>, pir::ArrayAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<int64_t>,
paddle::dialect::IntArrayAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<float>, pir::ArrayAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(phi::IntArray,
paddle::dialect::IntArrayAttribute);

Expand All @@ -66,6 +67,18 @@ struct IrAttrbuteCreator<std::vector<int32_t>> {
}
};

template <>
struct IrAttrbuteCreator<std::vector<float>> {
pir::ArrayAttribute operator()(std::vector<float> obj) const {
std::vector<pir::Attribute> attr_vec;
attr_vec.reserve(obj.size());
for (float x : obj) {
attr_vec.push_back(FloatAttribute::get(pir::IrContext::Instance(), x));
}
return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec);
}
};

template <typename T>
struct IrAttrTypeCast {
static T To(const pir::Attribute& attr) {
Expand Down Expand Up @@ -114,5 +127,17 @@ struct IrAttrTypeCast<std::vector<int64_t>> {
}
};

template <>
struct IrAttrTypeCast<std::vector<float>> {
static std::vector<float> To(const pir::Attribute& attr) {
std::vector<float> result;
auto array_attr = attr.dyn_cast<pir::ArrayAttribute>();
for (size_t i = 0; i < array_attr.size(); i++) {
result.push_back(array_attr.at(i).dyn_cast<pir::FloatAttribute>().data());
}
return result;
}
};

} // namespace drr
} // namespace pir
3 changes: 3 additions & 0 deletions paddle/fluid/pir/drr/ir_operation_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ static pir::Attribute CreateIrAttribute(const std::any& obj) {
} else if (obj.type() == typeid(std::vector<int64_t>)) {
return IrAttrbuteCreator<std::vector<int64_t>>()(
std::any_cast<std::vector<int64_t>>(obj));
} else if (obj.type() == typeid(std::vector<float>)) {
return IrAttrbuteCreator<std::vector<float>>()(
std::any_cast<std::vector<float>>(obj));
} else if (obj.type() == typeid(phi::IntArray)) {
return IrAttrbuteCreator<phi::IntArray>()(
std::any_cast<phi::IntArray>(obj));
Expand Down

0 comments on commit e3e37ff

Please sign in to comment.