-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pir]Supporting constant_folding_pass for train #60355
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,20 +126,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { | |
pir::PatternRewriter& rewriter) const override { // NOLINT | ||
VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name() | ||
<< "] op"; | ||
pir::Program new_program(rewriter.ir_context()); | ||
auto output_var_names = | ||
BuildProgramFromOperation(op, &new_program, rewriter); | ||
|
||
// execute program | ||
for (auto output_var_name : output_var_names) { | ||
exe_config_->skip_gc_vars.insert(output_var_name); | ||
} | ||
auto kernel_program = | ||
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_); | ||
paddle::framework::InterpreterCore core( | ||
place_, {}, kernel_program->block(), scope_, *exe_config_); | ||
|
||
core.Run({}); | ||
auto output_var_names = RunOp(op, rewriter, place_); | ||
|
||
// ParameterOp and ConstantTensorOp should be created in the top-level block | ||
rewriter.SetInsertionPointToStart( | ||
|
@@ -236,6 +223,28 @@ class ConstantFoldingPattern : public pir::RewritePattern { | |
return true; | ||
} | ||
|
||
protected: | ||
std::vector<std::string> RunOp( | ||
pir::Operation* op, | ||
pir::PatternRewriter& rewriter, | ||
phi::Place place) const { // NOLINT | ||
pir::Program new_program(rewriter.ir_context()); | ||
auto output_var_names = | ||
BuildProgramFromOperation(op, &new_program, rewriter); | ||
|
||
// execute program | ||
for (auto output_var_name : output_var_names) { | ||
exe_config_->skip_gc_vars.insert(output_var_name); | ||
} | ||
auto kernel_program = | ||
paddle::dialect::PdOpLowerToKernelPass(&new_program, place); | ||
paddle::framework::InterpreterCore core( | ||
place, {}, kernel_program->block(), scope_, *exe_config_); | ||
|
||
core.Run({}); | ||
return output_var_names; | ||
} | ||
|
||
std::vector<std::string> BuildProgramFromOperation( | ||
pir::Operation* op, | ||
pir::Program* new_program, | ||
|
@@ -299,14 +308,76 @@ class ConstantFoldingPattern : public pir::RewritePattern { | |
return output_var_names; | ||
} | ||
|
||
private: | ||
protected: | ||
size_t* counter_; | ||
phi::Place place_; | ||
paddle::framework::Scope* scope_; | ||
paddle::framework::interpreter::ExecutionConfig* exe_config_; | ||
std::vector<std::string>* deleted_vars_; | ||
}; | ||
|
||
class ConstantFoldingTrainingPattern : public ConstantFoldingPattern { | ||
public: | ||
ConstantFoldingTrainingPattern( | ||
pir::IrContext* context, | ||
size_t* counter, | ||
const phi::Place& place, | ||
paddle::framework::Scope* scope, | ||
paddle::framework::interpreter::ExecutionConfig* exe_config, | ||
std::vector<std::string>* deleted_vars) | ||
: ConstantFoldingPattern( | ||
context, counter, place, scope, exe_config, deleted_vars) {} | ||
|
||
bool Match(pir::Operation* op) const override { | ||
VLOG(4) << "constant_folding_training_pass applys match on [" << op->name() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. constant_folding_training_pass -> constant_folding_pass for train There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
已修改 |
||
<< "] op"; | ||
if (!ConstantFoldingPattern::Match(op)) { | ||
return false; | ||
} | ||
for (uint32_t i = 0; i < op->num_operands(); i++) { | ||
// inputs must come from or constant op | ||
auto* prev_op = pir::GetDefiningOpForInput(op, i); | ||
if (!prev_op || !prev_op->isa<pir::ConstantTensorOp>()) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
void Rewrite(pir::Operation* op, | ||
pir::PatternRewriter& rewriter) const override { // NOLINT | ||
VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. log同上 |
||
<< "] op"; | ||
|
||
auto output_var_names = RunOp(op, rewriter, phi::CPUPlace{}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 训练的执行place一定是cpu吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果是,需要限定ForTrain时,传入的place是cpu place There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
我修改了代码,如果是训练,就不把执行器的place传进pattern了,直接用cpu place~ |
||
|
||
// ConstantTensorOp should be created in the top-level block | ||
rewriter.SetInsertionPointToStart( | ||
rewriter.block()->parent_program()->block()); | ||
|
||
for (uint32_t i = 0; i < op->num_results(); i++) { | ||
if (!op->result(i) || !op->result(i).type()) { | ||
continue; | ||
} | ||
std::string output_var_name = output_var_names[i]; | ||
PADDLE_ENFORCE_NOT_NULL( | ||
scope_->FindVar(output_var_name), | ||
phi::errors::InvalidArgument("Parameter var [%s] not in scope.", | ||
output_var_name)); | ||
|
||
auto constant_op = rewriter.Build<pir::ConstantTensorOp>( | ||
rewriter.tensor_name_attr(output_var_name), op->result(i).type()); | ||
constant_op->set_attribute( | ||
kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)})); | ||
|
||
rewriter.ReplaceAllUsesWith(op->result(i), constant_op->result(0)); | ||
} | ||
rewriter.EraseOp(op); | ||
VLOG(4) << "constant_folding_pass applied rewrite on [" << op->name() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. log同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
已修改 |
||
<< "] op"; | ||
} | ||
}; | ||
|
||
class ConstantFoldingPass : public pir::Pass { | ||
public: | ||
ConstantFoldingPass() | ||
|
@@ -332,8 +403,14 @@ class ConstantFoldingPass : public pir::Pass { | |
scope_, phi::errors::InvalidArgument("scope can not be nullptr")); | ||
|
||
pir::RewritePatternSet ps(context); | ||
ps.Add<ConstantFoldingPattern>( | ||
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_); | ||
|
||
if (Has("train_mode") && Get<bool>("train_mode")) { | ||
ps.Add<ConstantFoldingTrainingPattern>( | ||
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_); | ||
} else { | ||
ps.Add<ConstantFoldingPattern>( | ||
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_); | ||
} | ||
patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); | ||
return true; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pattern命名可以改为ConstantFoldingPatternForTrain
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好滴