[PIR] Add yield instruction #64234

Merged — 5 commits, merged May 13, 2024 · Changes from all commits
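Overview (inferred from the diff below; the PR description itself is not captured on this page): this change adds a dedicated YieldInstruction to the PIR new executor. Previously, IfInstruction, PyLayerInstruction, and WhileInstruction each gathered the inputs of the trailing yield op in their sub-blocks via GetYiedOpInputs and copied them into the parent op's outputs themselves (CopyBranchOutput / ShareDatasToOutputs). Now the yield op is lowered to its own instruction that performs the copy while the sub-block's interpreter runs; the control-flow instructions keep only their skip-GC bookkeeping, and WhileInstruction extracts just the loop condition. A condensed before/after sketch of the IfInstruction::Run call site (simplified, not the literal diff):

// Before this PR: the parent instruction copied branch results after the run.
true_branch_inter_->Run({}, false);
CopyBranchOutput(
    true_branch_outputs_, output_vars_, true_branch_inter_->InnerScope());

// After this PR: the branch block ends in a yield op, which is lowered to a
// YieldInstruction, so the copy happens inside the sub-interpreter's Run.
true_branch_inter_->Run({}, false);  // trailing YieldInstruction writes outputs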
paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc
@@ -143,11 +143,6 @@ IfInstruction::IfInstruction(size_t id,
execution_config);

std::set<std::string> true_skip_gc_names_set;
for (auto value : GetYiedOpInputs(&true_branch_block)) {
true_branch_outputs_.push_back(true_branch_inter_->GetNameByValue(value));
true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value));
true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value));
}
// NOTE(zhangbo): According to the concept of control flow, child scopes
// should not control the lifecycle of parent scope variables.
for (auto value : true_outside_inputs) {
@@ -170,11 +165,6 @@ IfInstruction::IfInstruction(size_t id,
value_exec_info->NewChild(false_scope),
execution_config);
std::set<std::string> false_skip_gc_names_set;
for (auto value : GetYiedOpInputs(&false_branch_block)) {
false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value));
false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value));
false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value));
}
for (auto value : false_outside_inputs) {
false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value));
false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value));
@@ -245,8 +235,6 @@ void IfInstruction::Run() {
paddle::platform::DontClearMKLDNNCache(true_branch_inter_->GetPlace());
#endif
true_branch_inter_->Run({}, false);
CopyBranchOutput(
true_branch_outputs_, output_vars_, true_branch_inter_->InnerScope());
} else {
#ifdef PADDLE_WITH_DNNL
// Executor on being destroyed clears oneDNN cache and resets
@@ -255,8 +243,6 @@
paddle::platform::DontClearMKLDNNCache(false_branch_inter_->GetPlace());
#endif
false_branch_inter_->Run({}, false);
CopyBranchOutput(
false_branch_outputs_, output_vars_, false_branch_inter_->InnerScope());
}
// copy output
}
paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h
@@ -65,10 +65,6 @@ class IfInstruction : public InstructionBase {

PirInterpreter* false_branch_inter_ = nullptr;

std::vector<std::string> true_branch_outputs_;

std::vector<std::string> false_branch_outputs_;

// TODO(zhangbo): Currently, only the output of IfOp is included. In the
// future, need to consider how to support IfGradOp using IfOp value.
std::vector<std::string> true_skip_gc_names_;
paddle/fluid/framework/new_executor/instruction/control_flow/pylayer_instruction.cc
@@ -129,11 +129,6 @@ PyLayerInstruction::PyLayerInstruction(
execution_config);

std::set<std::string> fwd_skip_gc_names_set;
for (auto value : GetYiedOpInputs(&fwd_block)) {
fwd_outputs_.push_back(fwd_inter_->GetNameByValue(value));
fwd_skip_gc_names_.push_back(fwd_inter_->GetNameByValue(value));
fwd_skip_gc_names_set.insert(fwd_inter_->GetNameByValue(value));
}

// NOTE(zhangbo): According to the concept of control flow, child scopes
// should not control the lifecycle of parent scope variables.
@@ -166,7 +161,6 @@ void PyLayerInstruction::Run() {
paddle::platform::DontClearMKLDNNCache(fwd_inter_->GetPlace());
#endif
fwd_inter_->Run({}, false);
CopyBranchOutput(fwd_outputs_, output_vars_, fwd_inter_->InnerScope());
}

} // namespace framework
paddle/fluid/framework/new_executor/instruction/control_flow/pylayer_instruction.h
@@ -51,8 +51,6 @@ class PyLayerInstruction : public InstructionBase {

PirInterpreter* fwd_inter_ = nullptr;

std::vector<std::string> fwd_outputs_;

std::vector<std::string> fwd_skip_gc_names_;
};

paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc
@@ -36,6 +36,7 @@
#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/platform/onednn_helper.h"
@@ -128,10 +129,10 @@ WhileInstruction::WhileInstruction(
body_inter_ = std::unique_ptr<PirInterpreter>(new PirInterpreter(
place, {}, body_block_, body_scope, body_exe_info, execution_config));

auto body_block_outputs = GetYiedOpInputs(body_block_);
for (auto value : body_block_outputs) {
body_outputs_.push_back(body_inter_->GetNameByValue(value));
skip_gc_vars.insert(body_inter_->GetNameByValue(value));
if (body_block_->back().isa<pir::YieldOp>()) {
const auto& op = body_block_->back();
inner_cond_ = body_inter_->GetNameByValue(op.operand_source(0));
skip_gc_vars.insert(inner_cond_);
}
for (auto value : body_outside_inputs) {
auto name = body_inter_->GetNameByValue(value);
@@ -141,12 +142,6 @@
body_inter_->SetSkipGcVars(skip_gc_vars);

if (VLOG_IS_ON(6)) {
std::stringstream body_outputs;
for (const auto& var_name : body_outputs_) {
body_outputs << " " << var_name;
}
VLOG(6) << "body_outputs include: " << body_outputs.str();

std::stringstream body_skip_gc_names;
for (const auto& var_name : skip_gc_vars) {
body_skip_gc_names << " " << var_name;
@@ -193,44 +188,10 @@ void WhileInstruction::ShareOutputsToBlockArgs() {
}
}

void WhileInstruction::ShareDatasToOutputs() {
void WhileInstruction::ShareConditionData() {
auto inner_cond_var = body_inter_->local_scope()->GetVar(inner_cond_);
cond_var_->GetMutable<phi::DenseTensor>()->ShareDataWith(
body_inter_->local_scope()
->GetVar(body_outputs_[0])
->Get<phi::DenseTensor>());
for (size_t i = 0; i < outputs_.size(); ++i) {
auto& out_var_name = body_outputs_[i + 1];
auto* out_var = body_inter_->local_scope()->GetVar(out_var_name);
VLOG(6) << "share data from " << out_var_name << " -> " << i << " output";
if (out_var->IsType<phi::DenseTensor>()) {
outputs_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
out_var->Get<phi::DenseTensor>());
VLOG(6) << "share data from " << out_var_name << "[" << out_var << "]"
<< " -> " << i << " output[" << outputs_[i] << "]";
} else if (out_var->IsType<phi::TensorArray>()) {
const auto& inner_array = out_var->Get<phi::TensorArray>();
auto* output_array = outputs_[i]->GetMutable<phi::TensorArray>();
*output_array = inner_array;
} else {
PADDLE_THROW(
phi::errors::Unimplemented("unsupported type %d", out_var->Type()));
}

VLOG(6) << "done";
}

for (size_t i = 0; i < outputs_.size(); ++i) {
auto& out_var_name = body_outputs_[i + 1];
auto* out_var = body_inter_->local_scope()->GetVar(out_var_name);
if (out_var->IsType<phi::DenseTensor>()) {
// NOTE(zhangbo): Delete the input of the yield operator, except for the
// external vars of the block.
if (external_input_names_.count(out_var_name) == 0) {
VLOG(6) << "clear internel input " << out_var_name;
out_var->GetMutable<phi::DenseTensor>()->clear();
}
}
}
inner_cond_var->Get<phi::DenseTensor>());
}

void WhileInstruction::SetOutputHooks(
@@ -266,8 +227,8 @@ void WhileInstruction::Run() {
ShareOutputsToBlockArgs();
VLOG(6) << "while instruction interpretercore run";
body_inter_->Run({}, false);
VLOG(6) << "while instruction get value form body block";
ShareDatasToOutputs();
VLOG(6) << "while instruction get condition value form body block";
ShareConditionData();
}
VLOG(6) << "while instruction run done";
}
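The while-specific wrinkle: the body block's trailing yield carries the next-iteration condition as operand 0, followed by the loop-carried values. YieldInstruction skips operand 0 when its parent is a while op (see yield_instruction.cc below), so WhileInstruction only has to pull the condition out of the body scope. A minimal sketch of the layout this relies on, assuming the body block ends in a pir::YieldOp as the new code checks:

// Assumed body-block shape:  cf.yield %next_cond, %carried_0, %carried_1, ...
pir::Value NextIterationCond(pir::Block* body_block) {
  // Operand 0 is the next-iteration condition; WhileInstruction records its
  // variable name as inner_cond_ and reads it back in ShareConditionData().
  return body_block->back().operand_source(0);
  // Operands 1..n line up with the parent while op's results and are copied
  // by YieldInstruction::Run(), which skips operand 0 for while parents.
}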
paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h
@@ -70,18 +70,18 @@ class WhileInstruction : public InstructionBase {
// Pass argument to body_block for execution.
void ShareOutputsToBlockArgs();

// Get return value from body_block after each execution.
void ShareDatasToOutputs();
// Get condition value from body_block after each execution.
void ShareConditionData();

std::string name_{"while_instruction"};

Variable* cond_var_;
std::string inner_cond_;

std::vector<Variable*> inputs_;
std::vector<Variable*> outputs_;

std::unique_ptr<PirInterpreter> body_inter_;
std::vector<std::string> body_outputs_;
std::set<std::string> external_input_names_;

::pir::Block* body_block_;
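Putting the pieces together, the run loop this header now supports looks roughly as follows (condensed from the Run hunk above; the loop-condition check via GetCondData is a simplification, not the literal control flow):

// Condensed sketch of WhileInstruction::Run after this PR:
while (GetCondData(cond_var_->Get<phi::DenseTensor>())) {
  ShareOutputsToBlockArgs();    // feed current outputs in as body block args
  body_inter_->Run({}, false);  // trailing YieldInstruction fills outputs_
  ShareConditionData();         // copy yield operand 0 back into cond_var_
}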
paddle/fluid/framework/new_executor/instruction/control_flow/yield_instruction.cc
@@ -0,0 +1,77 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/instruction/control_flow/yield_instruction.h"

#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"

namespace paddle {
namespace framework {

YieldInstruction::YieldInstruction(size_t id,
const platform::Place &place,
::pir::Operation *op,
ValueExecutionInfo *value_exe_info)
: InstructionBase(id, place), op_(op) {
VLOG(6) << "construct yield instruction";

auto parent_op = op->GetParentOp();

std::unordered_map<pir::Value, std::vector<int>> inputs;
for (size_t i = 0; i < op->num_operands(); ++i) {
// Skip the first input (cond) when the parent op is a while op.
if (parent_op->isa<paddle::dialect::WhileOp>() && i == 0) {
continue;
}
auto in = op->operand_source(i);
inputs.emplace(in, GetValueIds(in, *value_exe_info));
input_vars_.push_back(value_exe_info->GetVarByValue(in));
}
SetInputs(inputs);

for (size_t i = 0; i < parent_op->num_results(); ++i) {
output_vars_.push_back(value_exe_info->GetVarByValue(parent_op->result(i)));
}

PADDLE_ENFORCE_EQ(
input_vars_.size(),
output_vars_.size(),
phi::errors::InvalidArgument("The number of inputs in YieldOp and "
"outputs of parent op must be equal."
"But received %d and %d.",
input_vars_.size(),
output_vars_.size()));
}

void YieldInstruction::Run() {
for (size_t i = 0; i < input_vars_.size(); ++i) {
if (input_vars_[i]->IsType<phi::DenseTensor>()) {
output_vars_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
input_vars_[i]->Get<phi::DenseTensor>());
} else if (input_vars_[i]->IsType<phi::TensorArray>()) {
const auto &inner_array = input_vars_[i]->Get<phi::TensorArray>();
auto *output_array = output_vars_[i]->GetMutable<phi::TensorArray>();
*output_array = inner_array;
} else {
PADDLE_THROW(phi::errors::Unimplemented("unsupported type %d",
input_vars_[i]->Type()));
}
}
}

} // namespace framework
} // namespace paddle
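Note the sharing semantics in Run(): the DenseTensor branch uses ShareDataWith, so the parent op's output aliases the sub-block tensor's buffer rather than deep-copying it, while the TensorArray branch assigns the whole array. A minimal sketch of the aliasing this relies on (assuming phi::DenseTensor's usual ShareDataWith behavior):

// After ShareDataWith, both tensors reference a single allocation.
phi::DenseTensor inner;      // produced inside the sub-block's scope
phi::DenseTensor outer;      // backs the parent op's output variable
outer.ShareDataWith(inner);  // copies meta info and shares the buffer, no copy
// Later reads of `outer` see the data the sub-block computed.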
paddle/fluid/framework/new_executor/instruction/control_flow/yield_instruction.h
@@ -0,0 +1,47 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"

namespace paddle {
namespace framework {
class ValueExecutionInfo;

class YieldInstruction : public InstructionBase {
public:
YieldInstruction(size_t id,
const platform::Place& place,
::pir::Operation* op,
ValueExecutionInfo* value_exe_info);

void Run() override;

const std::string& Name() const override { return name_; }

::pir::Operation* Operation() const override { return op_; }

private:
::pir::Operation* op_;

std::string name_{"yield_instruction"};

std::vector<Variable*> input_vars_;

std::vector<Variable*> output_vars_;
};

} // namespace framework
} // namespace paddle
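For orientation, a hypothetical construction call showing the constructor's contract; the id, place, op, and value-execution info would be supplied by the interpreter's instruction-building pass, and the names below are illustrative rather than taken from this PR:

// Hypothetical usage sketch, not code from this PR.
auto yield_instr = std::make_unique<YieldInstruction>(
    /*id=*/op_idx, place, yield_op, value_exe_info);
yield_instr->Run();  // shares each yield operand into the parent op's results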
paddle/fluid/framework/new_executor/instruction/instruction_util.cc
@@ -197,18 +197,6 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) {
return OpFuncType::kGpuAsync;
}

std::vector<pir::Value> GetYiedOpInputs(pir::Block* block) {
std::vector<pir::Value> vec_res;

if (block && !block->empty() && block->back().isa<pir::YieldOp>()) {
auto& op = block->back();
for (size_t i = 0; i < op.num_operands(); ++i) {
vec_res.emplace_back(op.operand_source(i));
}
}
return vec_res;
}

void GetInputIds(pir::Operation* op,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* input_ids) {
@@ -420,26 +408,5 @@ bool GetCondData(const phi::DenseTensor& cond) {
return cpu_cond->data<bool>()[0];
}

void CopyBranchOutput(const std::vector<std::string>& var_names,
const std::vector<Variable*>& output_vars,
Scope* inner_scope) {
for (size_t i = 0; i < var_names.size(); ++i) {
auto* inner_var = inner_scope->GetVar(var_names[i]);

if (inner_var->IsType<phi::DenseTensor>()) {
output_vars[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
inner_var->Get<phi::DenseTensor>());

} else if (inner_var->IsType<phi::TensorArray>()) {
const auto& inner_array = inner_var->Get<phi::TensorArray>();
auto* output_array = output_vars[i]->GetMutable<phi::TensorArray>();
*output_array = inner_array;
} else {
PADDLE_THROW(
phi::errors::Unimplemented("unsupported type %d", inner_var->Type()));
}
}
}

} // namespace framework
} // namespace paddle
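With YieldInstruction handling the copy, GetYiedOpInputs and CopyBranchOutput have no remaining callers, so both helpers are deleted here along with their declarations in the header below.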
paddle/fluid/framework/new_executor/instruction/instruction_util.h
@@ -43,8 +43,6 @@ platform::DeviceContext* ParseDeviceContext(
OpFuncType AnalyseOpFuncType(::pir::Operation* op,
const platform::Place& place);

std::vector<pir::Value> GetYiedOpInputs(pir::Block* block);

void GetInputIds(pir::Operation* op,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* input_ids);
@@ -67,9 +65,5 @@ void InsertInplacedExternalInputsToOuts(

bool GetCondData(const phi::DenseTensor& cond);

void CopyBranchOutput(const std::vector<std::string>& var_names,
const std::vector<Variable*>& output_vars,
Scope* inner_scope);

} // namespace framework
} // namespace paddle