Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pir]Supporting constant_folding_pass for train #60355

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 94 additions & 17 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
pir::PatternRewriter& rewriter) const override { // NOLINT
VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
<< "] op";
pir::Program new_program(rewriter.ir_context());
auto output_var_names =
BuildProgramFromOperation(op, &new_program, rewriter);

// execute program
for (auto output_var_name : output_var_names) {
exe_config_->skip_gc_vars.insert(output_var_name);
}
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
paddle::framework::InterpreterCore core(
place_, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});
auto output_var_names = RunOp(op, rewriter, place_);

// ParameterOp and ConstantTensorOp should be created in the top-level block
rewriter.SetInsertionPointToStart(
Expand Down Expand Up @@ -236,6 +223,28 @@ class ConstantFoldingPattern : public pir::RewritePattern {
return true;
}

protected:
// Executes a single op in isolation: clones it into a temporary program,
// lowers that program to the kernel dialect and runs it on `place`.
// Returns the names of the scope variables holding the op's outputs.
std::vector<std::string> RunOp(
    pir::Operation* op,
    pir::PatternRewriter& rewriter,
    phi::Place place) const {  // NOLINT
  pir::Program temp_program(rewriter.ir_context());
  std::vector<std::string> var_names =
      BuildProgramFromOperation(op, &temp_program, rewriter);

  // Keep the computed outputs alive after execution so the pattern can
  // later wrap them in constant/parameter ops.
  exe_config_->skip_gc_vars.insert(var_names.begin(), var_names.end());

  auto kernel_program =
      paddle::dialect::PdOpLowerToKernelPass(&temp_program, place);
  paddle::framework::InterpreterCore interpreter(
      place, {}, kernel_program->block(), scope_, *exe_config_);
  interpreter.Run({});

  return var_names;
}

std::vector<std::string> BuildProgramFromOperation(
pir::Operation* op,
pir::Program* new_program,
Expand Down Expand Up @@ -299,14 +308,76 @@ class ConstantFoldingPattern : public pir::RewritePattern {
return output_var_names;
}

private:
protected:
size_t* counter_;
phi::Place place_;
paddle::framework::Scope* scope_;
paddle::framework::interpreter::ExecutionConfig* exe_config_;
std::vector<std::string>* deleted_vars_;
};

class ConstantFoldingTrainingPattern : public ConstantFoldingPattern {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pattern命名可以改为ConstantFoldingPatternForTrain

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pattern命名可以改为ConstantFoldingPatternForTrain

好滴

 public:
  // Train-mode variant of constant folding. All shared state (counter,
  // place, scope, execution config, deleted-vars list) is forwarded to the
  // ConstantFoldingPattern base; only Match/Rewrite behavior differs.
  ConstantFoldingTrainingPattern(
      pir::IrContext* context,
      size_t* counter,
      const phi::Place& place,
      paddle::framework::Scope* scope,
      paddle::framework::interpreter::ExecutionConfig* exe_config,
      std::vector<std::string>* deleted_vars)
      : ConstantFoldingPattern(
            context, counter, place, scope, exe_config, deleted_vars) {}

bool Match(pir::Operation* op) const override {
VLOG(4) << "constant_folding_training_pass applys match on [" << op->name()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

constant_folding_training_pass -> constant_folding_pass for train

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

constant_folding_training_pass -> constant_folding_pass for train

已修改

<< "] op";
if (!ConstantFoldingPattern::Match(op)) {
return false;
}
for (uint32_t i = 0; i < op->num_operands(); i++) {
// every input must be produced by a pir::ConstantTensorOp
auto* prev_op = pir::GetDefiningOpForInput(op, i);
if (!prev_op || !prev_op->isa<pir::ConstantTensorOp>()) {
return false;
}
}
return true;
}

void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override { // NOLINT
VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

log同上

<< "] op";

auto output_var_names = RunOp(op, rewriter, phi::CPUPlace{});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

训练的执行place一定是cpu吗

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果是,需要限定ForTrain时,传入的place是cpu place

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果是,需要限定ForTrain时,传入的place是cpu place
我修改了代码,如果是训练,就不把执行器的place传进pattern了,直接用cpu place~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我修改了代码,如果是训练,就不把执行器的place传进pattern了,直接用cpu place~

我修改了代码,如果是训练,就不把执行器的place传进pattern了,直接用cpu place~


// ConstantTensorOp should be created in the top-level block
rewriter.SetInsertionPointToStart(
rewriter.block()->parent_program()->block());

for (uint32_t i = 0; i < op->num_results(); i++) {
if (!op->result(i) || !op->result(i).type()) {
continue;
}
std::string output_var_name = output_var_names[i];
PADDLE_ENFORCE_NOT_NULL(
scope_->FindVar(output_var_name),
phi::errors::InvalidArgument("Parameter var [%s] not in scope.",
output_var_name));

auto constant_op = rewriter.Build<pir::ConstantTensorOp>(
rewriter.tensor_name_attr(output_var_name), op->result(i).type());
constant_op->set_attribute(
kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));

rewriter.ReplaceAllUsesWith(op->result(i), constant_op->result(0));
}
rewriter.EraseOp(op);
VLOG(4) << "constant_folding_pass applied rewrite on [" << op->name()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

log同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

log同上

已修改

<< "] op";
}
};

class ConstantFoldingPass : public pir::Pass {
public:
ConstantFoldingPass()
Expand All @@ -332,8 +403,14 @@ class ConstantFoldingPass : public pir::Pass {
scope_, phi::errors::InvalidArgument("scope can not be nullptr"));

pir::RewritePatternSet ps(context);
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);

if (Has("train_mode") && Get<bool>("train_mode")) {
ps.Add<ConstantFoldingTrainingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
} else {
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
}
patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
}
Expand Down
31 changes: 29 additions & 2 deletions test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,10 @@ void BuildConstantFoldingProgram(pir::Program *program,
paddle::platform::DeviceContextPool::Instance().Get(
paddle::platform::CPUPlace());

auto op1 = builder.Build<pir::ParameterOp>("a", dense_tensor_dtype);
auto op2 = builder.Build<pir::ParameterOp>("b", dense_tensor_dtype);
auto op1 = builder.Build<pir::ConstantTensorOp>(builder.tensor_name_attr("a"),
dense_tensor_dtype);
auto op2 = builder.Build<pir::ConstantTensorOp>(builder.tensor_name_attr("b"),
dense_tensor_dtype);

auto op3 =
builder.Build<paddle::dialect::AddOp>(op1->result(0), op2->result(0));
Expand Down Expand Up @@ -493,6 +495,31 @@ TEST(constant_folding, ConstantFolding) {
EXPECT_EQ(program.block()->size(), 2u);
}

// Verifies constant folding in train mode: folded results are kept as
// ConstantTensorOp outputs in the program rather than being fully removed,
// so 4 ops remain in the top-level block after folding + DCE.
TEST(constant_folding, ConstantFolding_Train) {
  pir::IrContext *context = pir::IrContext::Instance();
  context->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
  context->GetOrRegisterDialect<pir::BuiltinDialect>();

  paddle::framework::Scope scope;
  pir::Program program(context);
  BuildConstantFoldingProgram(&program, context, &scope);

  // Configure the pass for training: CPU execution place, the scope that
  // holds the constant inputs, and the train_mode switch.
  phi::Place cpu_place = phi::CPUPlace();
  std::unique_ptr<pir::Pass> folding_pass = pir::CreateConstantFoldingPass();
  folding_pass->SetNotOwned(pir::kPlaceAttr, &cpu_place);
  folding_pass->SetNotOwned(pir::kParamScopeAttr, &scope);
  folding_pass->Set("train_mode", new bool(true));

  pir::PassManager pass_manager(context);
  pass_manager.AddPass(std::move(folding_pass));
  pass_manager.AddPass(pir::CreateDeadCodeEliminationPass());
  pass_manager.EnableIRPrinting();

  CHECK_EQ(pass_manager.Run(&program), true);
  EXPECT_EQ(program.block()->size(), 4u);
}

void BuildConcatProgram(pir::Program *program, pir::IrContext *ctx) {
pir::Builder builder = pir::Builder(ctx, program->block());
auto x = builder
Expand Down