[PIR] Add yield instruction (PaddlePaddle#64234)
* add yield_instruction

* rm CopyBranchOutput

* fix while_instruction
huangjiyi authored and co63oc committed May 13, 2024
1 parent 310b334 commit f19d2d1
Showing 11 changed files with 140 additions and 117 deletions.
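Note (editorial, not part of the commit): before this change, IfInstruction, PyLayerInstruction, and WhileInstruction each collected the operands of their block's trailing cf.yield op (via GetYiedOpInputs) and copied them into the parent op's output variables after the sub-block ran (via CopyBranchOutput or ShareDatasToOutputs). This commit moves that copy into a dedicated YieldInstruction that executes as the last instruction inside the sub-block, so the control-flow instructions keep only the bookkeeping they still need (skip-GC names, plus the loop condition for while). A minimal standalone sketch of the resulting data flow, with std::shared_ptr standing in for phi::DenseTensor::ShareDataWith buffer aliasing (toy types, not Paddle's API):

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Toy stand-in for a Variable holding a DenseTensor; sharing the buffer
// pointer models ShareDataWith (alias, no copy).
struct Var { std::shared_ptr<std::vector<float>> buf; };

// What YieldInstruction::Run does conceptually: alias each yield input
// into the matching output variable of the parent control-flow op.
void YieldRun(const std::vector<Var*>& yield_inputs,
              const std::vector<Var*>& parent_outputs) {
  for (std::size_t i = 0; i < yield_inputs.size(); ++i) {
    parent_outputs[i]->buf = yield_inputs[i]->buf;  // "ShareDataWith"
  }
}

int main() {
  Var inner{std::make_shared<std::vector<float>>(3, 1.0f)};  // produced in the block
  Var out;                                                   // parent op's output
  YieldRun({&inner}, {&out});
  std::cout << out.buf->at(0) << std::endl;  // 1 (output aliases the yielded data)
}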
@@ -143,11 +143,6 @@ IfInstruction::IfInstruction(size_t id,
       execution_config);

   std::set<std::string> true_skip_gc_names_set;
-  for (auto value : GetYiedOpInputs(&true_branch_block)) {
-    true_branch_outputs_.push_back(true_branch_inter_->GetNameByValue(value));
-    true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value));
-    true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value));
-  }
   // NOTE(zhangbo): According to the concept of control flow, child scopes
   // should not control the lifecycle of parent scope variables.
   for (auto value : true_outside_inputs) {
@@ -170,11 +165,6 @@ IfInstruction::IfInstruction(size_t id,
       value_exec_info->NewChild(false_scope),
       execution_config);
   std::set<std::string> false_skip_gc_names_set;
-  for (auto value : GetYiedOpInputs(&false_branch_block)) {
-    false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value));
-    false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value));
-    false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value));
-  }
   for (auto value : false_outside_inputs) {
     false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value));
     false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value));
@@ -245,8 +235,6 @@ void IfInstruction::Run() {
     paddle::platform::DontClearMKLDNNCache(true_branch_inter_->GetPlace());
 #endif
     true_branch_inter_->Run({}, false);
-    CopyBranchOutput(
-        true_branch_outputs_, output_vars_, true_branch_inter_->InnerScope());
   } else {
 #ifdef PADDLE_WITH_DNNL
     // Executor on being destroyed clears oneDNN cache and resets
@@ -255,8 +243,6 @@
     paddle::platform::DontClearMKLDNNCache(false_branch_inter_->GetPlace());
 #endif
     false_branch_inter_->Run({}, false);
-    CopyBranchOutput(
-        false_branch_outputs_, output_vars_, false_branch_inter_->InnerScope());
   }
   // copy output
 }
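Note (editorial): after this change the branch interpreters produce the IfOp outputs themselves, because each branch block's trailing cf.yield is executed as a YieldInstruction, which is why the CopyBranchOutput calls above could be deleted. A toy model of the new Run() control flow (simplified; the real code runs whole PIR blocks through PirInterpreter):

#include <cassert>
#include <memory>

int main() {
  bool cond = true;
  std::shared_ptr<int> if_output;            // the IfOp's output variable
  auto true_branch = [&] {
    auto local = std::make_shared<int>(1);   // computed inside the block
    if_output = local;                       // trailing YieldInstruction
  };
  auto false_branch = [&] {
    auto local = std::make_shared<int>(2);
    if_output = local;
  };
  cond ? true_branch() : false_branch();     // IfInstruction::Run
  assert(*if_output == 1);                   // no separate CopyBranchOutput step
}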
@@ -65,10 +65,6 @@ class IfInstruction : public InstructionBase {

   PirInterpreter* false_branch_inter_ = nullptr;

-  std::vector<std::string> true_branch_outputs_;
-
-  std::vector<std::string> false_branch_outputs_;
-
   // TODO(zhangbo): Currently, only the output of IfOp is included. In the
   // future, need to consider how to support IfGradOp using IfOp value.
   std::vector<std::string> true_skip_gc_names_;
@@ -129,11 +129,6 @@ PyLayerInstruction::PyLayerInstruction(
       execution_config);

   std::set<std::string> fwd_skip_gc_names_set;
-  for (auto value : GetYiedOpInputs(&fwd_block)) {
-    fwd_outputs_.push_back(fwd_inter_->GetNameByValue(value));
-    fwd_skip_gc_names_.push_back(fwd_inter_->GetNameByValue(value));
-    fwd_skip_gc_names_set.insert(fwd_inter_->GetNameByValue(value));
-  }

   // NOTE(zhangbo): According to the concept of control flow, child scopes
   // should not control the lifecycle of parent scope variables.
@@ -166,7 +161,6 @@ void PyLayerInstruction::Run() {
     paddle::platform::DontClearMKLDNNCache(fwd_inter_->GetPlace());
 #endif
   fwd_inter_->Run({}, false);
-  CopyBranchOutput(fwd_outputs_, output_vars_, fwd_inter_->InnerScope());
 }

 }  // namespace framework
@@ -51,8 +51,6 @@ class PyLayerInstruction : public InstructionBase {

   PirInterpreter* fwd_inter_ = nullptr;

-  std::vector<std::string> fwd_outputs_;
-
   std::vector<std::string> fwd_skip_gc_names_;
 };

@@ -36,6 +36,7 @@
 #include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
 #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
 #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
+#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"

 #ifdef PADDLE_WITH_DNNL
 #include "paddle/fluid/platform/onednn_helper.h"
@@ -128,10 +129,10 @@ WhileInstruction::WhileInstruction(
   body_inter_ = std::unique_ptr<PirInterpreter>(new PirInterpreter(
       place, {}, body_block_, body_scope, body_exe_info, execution_config));

-  auto body_block_outputs = GetYiedOpInputs(body_block_);
-  for (auto value : body_block_outputs) {
-    body_outputs_.push_back(body_inter_->GetNameByValue(value));
-    skip_gc_vars.insert(body_inter_->GetNameByValue(value));
+  if (body_block_->back().isa<pir::YieldOp>()) {
+    const auto& op = body_block_->back();
+    inner_cond_ = body_inter_->GetNameByValue(op.operand_source(0));
+    skip_gc_vars.insert(inner_cond_);
   }
   for (auto value : body_outside_inputs) {
     auto name = body_inter_->GetNameByValue(value);
@@ -141,12 +142,6 @@ WhileInstruction::WhileInstruction(
   body_inter_->SetSkipGcVars(skip_gc_vars);

   if (VLOG_IS_ON(6)) {
-    std::stringstream body_outputs;
-    for (const auto& var_name : body_outputs_) {
-      body_outputs << " " << var_name;
-    }
-    VLOG(6) << "body_outputs include: " << body_outputs.str();
-
     std::stringstream body_skip_gc_names;
     for (const auto& var_name : skip_gc_vars) {
       body_skip_gc_names << " " << var_name;
@@ -193,44 +188,10 @@ void WhileInstruction::ShareOutputsToBlockArgs() {
   }
 }

-void WhileInstruction::ShareDatasToOutputs() {
+void WhileInstruction::ShareConditionData() {
+  auto inner_cond_var = body_inter_->local_scope()->GetVar(inner_cond_);
   cond_var_->GetMutable<phi::DenseTensor>()->ShareDataWith(
-      body_inter_->local_scope()
-          ->GetVar(body_outputs_[0])
-          ->Get<phi::DenseTensor>());
-  for (size_t i = 0; i < outputs_.size(); ++i) {
-    auto& out_var_name = body_outputs_[i + 1];
-    auto* out_var = body_inter_->local_scope()->GetVar(out_var_name);
-    VLOG(6) << "share data from " << out_var_name << " -> " << i << " output";
-    if (out_var->IsType<phi::DenseTensor>()) {
-      outputs_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
-          out_var->Get<phi::DenseTensor>());
-      VLOG(6) << "share data from " << out_var_name << "[" << out_var << "]"
-              << " -> " << i << " output[" << outputs_[i] << "]";
-    } else if (out_var->IsType<phi::TensorArray>()) {
-      const auto& inner_array = out_var->Get<phi::TensorArray>();
-      auto* output_array = outputs_[i]->GetMutable<phi::TensorArray>();
-      *output_array = inner_array;
-    } else {
-      PADDLE_THROW(
-          phi::errors::Unimplemented("unsupported type %d", out_var->Type()));
-    }
-
-    VLOG(6) << "done";
-  }
-
-  for (size_t i = 0; i < outputs_.size(); ++i) {
-    auto& out_var_name = body_outputs_[i + 1];
-    auto* out_var = body_inter_->local_scope()->GetVar(out_var_name);
-    if (out_var->IsType<phi::DenseTensor>()) {
-      // NOTE(zhangbo): Delete the input of the yield operator, except for the
-      // external vars of the block.
-      if (external_input_names_.count(out_var_name) == 0) {
-        VLOG(6) << "clear internel input " << out_var_name;
-        out_var->GetMutable<phi::DenseTensor>()->clear();
-      }
-    }
-  }
+      inner_cond_var->Get<phi::DenseTensor>());
 }

 void WhileInstruction::SetOutputHooks(
@@ -266,8 +227,8 @@ void WhileInstruction::Run() {
     ShareOutputsToBlockArgs();
     VLOG(6) << "while instruction interpretercore run";
     body_inter_->Run({}, false);
-    VLOG(6) << "while instruction get value form body block";
-    ShareDatasToOutputs();
+    VLOG(6) << "while instruction get condition value form body block";
+    ShareConditionData();
   }
   VLOG(6) << "while instruction run done";
 }
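Note (editorial): for while, the body block's trailing cf.yield carries the next-iteration condition as operand 0. YieldInstruction skips that operand (its remaining operands alias the while op's outputs), and ShareConditionData() reads it back through inner_cond_ after each body run. A toy model of one loop execution under those assumptions (the real code shares phi::DenseTensor buffers between scopes):

#include <iostream>

int main() {
  bool cond = true;  // cond_var_
  int x = 0;         // one loop-carried output of the while op
  while (cond) {                 // WhileInstruction::Run
    // ShareOutputsToBlockArgs(): feed x back in, then run the body block.
    int out = x + 1;             // body computation
    bool next_cond = out < 3;    // body computes the next condition
    // Body ends in "yield next_cond, out": YieldInstruction skips operand 0
    // and aliases out into the while op's output ...
    x = out;
    // ... and ShareConditionData() copies operand 0 into cond_var_.
    cond = next_cond;
  }
  std::cout << x << std::endl;  // 3
}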
@@ -70,18 +70,18 @@ class WhileInstruction : public InstructionBase {
   // Pass argument to body_block for execution.
   void ShareOutputsToBlockArgs();

-  // Get return value from body_block after each execution.
-  void ShareDatasToOutputs();
+  // Get condition value from body_block after each execution.
+  void ShareConditionData();

   std::string name_{"while_instruction"};

   Variable* cond_var_;
+  std::string inner_cond_;

   std::vector<Variable*> inputs_;
   std::vector<Variable*> outputs_;

   std::unique_ptr<PirInterpreter> body_inter_;
-  std::vector<std::string> body_outputs_;
   std::set<std::string> external_input_names_;

   ::pir::Block* body_block_;
@@ -0,0 +1,77 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/new_executor/instruction/control_flow/yield_instruction.h"
+
+#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
+#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
+#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
+#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
+
+namespace paddle {
+namespace framework {
+
+YieldInstruction::YieldInstruction(size_t id,
+                                   const platform::Place &place,
+                                   ::pir::Operation *op,
+                                   ValueExecutionInfo *value_exe_info)
+    : InstructionBase(id, place), op_(op) {
+  VLOG(6) << "construct yield instruction";
+
+  auto parent_op = op->GetParentOp();
+
+  std::unordered_map<pir::Value, std::vector<int>> inputs;
+  for (size_t i = 0; i < op->num_operands(); ++i) {
+    // Skip the first input (cond) when the parent op is a while op.
+    if (parent_op->isa<paddle::dialect::WhileOp>() && i == 0) {
+      continue;
+    }
+    auto in = op->operand_source(i);
+    inputs.emplace(in, GetValueIds(in, *value_exe_info));
+    input_vars_.push_back(value_exe_info->GetVarByValue(in));
+  }
+  SetInputs(inputs);
+
+  for (size_t i = 0; i < parent_op->num_results(); ++i) {
+    output_vars_.push_back(value_exe_info->GetVarByValue(parent_op->result(i)));
+  }
+
+  PADDLE_ENFORCE_EQ(
+      input_vars_.size(),
+      output_vars_.size(),
+      phi::errors::InvalidArgument("The number of inputs in YieldOp and "
+                                   "outputs of parent op must be equal."
+                                   "But received %d and %d.",
+                                   input_vars_.size(),
+                                   output_vars_.size()));
+}
+
+void YieldInstruction::Run() {
+  for (size_t i = 0; i < input_vars_.size(); ++i) {
+    if (input_vars_[i]->IsType<phi::DenseTensor>()) {
+      output_vars_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
+          input_vars_[i]->Get<phi::DenseTensor>());
+    } else if (input_vars_[i]->IsType<phi::TensorArray>()) {
+      const auto &inner_array = input_vars_[i]->Get<phi::TensorArray>();
+      auto *output_array = output_vars_[i]->GetMutable<phi::TensorArray>();
+      *output_array = inner_array;
+    } else {
+      PADDLE_THROW(phi::errors::Unimplemented("unsupported type %d",
+                                              input_vars_[i]->Type()));
+    }
+  }
+}
+
+}  // namespace framework
+}  // namespace paddle
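Note (editorial): the PADDLE_ENFORCE_EQ above relies on the skip in the constructor loop. For IfOp and PyLayerOp every yield operand pairs with a parent result, while for WhileOp operand 0 is the next condition and is excluded, leaving one input per loop-carried result. A toy restatement of that counting (counts only, not Paddle's API):

#include <cassert>
#include <cstddef>

// Number of yield operands that YieldInstruction pairs with parent results.
std::size_t paired_yield_inputs(std::size_t num_operands, bool parent_is_while) {
  return parent_is_while ? num_operands - 1 : num_operands;  // skip the cond
}

int main() {
  assert(paired_yield_inputs(3, /*parent_is_while=*/true) == 2);   // yield(cond, a, b)
  assert(paired_yield_inputs(2, /*parent_is_while=*/false) == 2);  // yield(a, b)
}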
@@ -0,0 +1,47 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
+
+namespace paddle {
+namespace framework {
+class ValueExecutionInfo;
+
+class YieldInstruction : public InstructionBase {
+ public:
+  YieldInstruction(size_t id,
+                   const platform::Place& place,
+                   ::pir::Operation* op,
+                   ValueExecutionInfo* value_exe_info);
+
+  void Run() override;
+
+  const std::string& Name() const override { return name_; }
+
+  ::pir::Operation* Operation() const override { return op_; }
+
+ private:
+  ::pir::Operation* op_;
+
+  std::string name_{"yield_instruction"};
+
+  std::vector<Variable*> input_vars_;
+
+  std::vector<Variable*> output_vars_;
+};
+
+}  // namespace framework
+}  // namespace paddle
@@ -197,18 +197,6 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) {
   return OpFuncType::kGpuAsync;
 }

-std::vector<pir::Value> GetYiedOpInputs(pir::Block* block) {
-  std::vector<pir::Value> vec_res;
-
-  if (block && !block->empty() && block->back().isa<pir::YieldOp>()) {
-    auto& op = block->back();
-    for (size_t i = 0; i < op.num_operands(); ++i) {
-      vec_res.emplace_back(op.operand_source(i));
-    }
-  }
-  return vec_res;
-}
-
 void GetInputIds(pir::Operation* op,
                  const ValueExecutionInfo& value_exec_info,
                  std::unordered_map<pir::Value, std::vector<int>>* input_ids) {
@@ -420,26 +408,5 @@ bool GetCondData(const phi::DenseTensor& cond) {
   return cpu_cond->data<bool>()[0];
 }

-void CopyBranchOutput(const std::vector<std::string>& var_names,
-                      const std::vector<Variable*>& output_vars,
-                      Scope* inner_scope) {
-  for (size_t i = 0; i < var_names.size(); ++i) {
-    auto* inner_var = inner_scope->GetVar(var_names[i]);
-
-    if (inner_var->IsType<phi::DenseTensor>()) {
-      output_vars[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
-          inner_var->Get<phi::DenseTensor>());
-
-    } else if (inner_var->IsType<phi::TensorArray>()) {
-      const auto& inner_array = inner_var->Get<phi::TensorArray>();
-      auto* output_array = output_vars[i]->GetMutable<phi::TensorArray>();
-      *output_array = inner_array;
-    } else {
-      PADDLE_THROW(
-          phi::errors::Unimplemented("unsupported type %d", inner_var->Type()));
-    }
-  }
-}
-
 }  // namespace framework
 }  // namespace paddle
@@ -43,8 +43,6 @@ platform::DeviceContext* ParseDeviceContext(
 OpFuncType AnalyseOpFuncType(::pir::Operation* op,
                              const platform::Place& place);

-std::vector<pir::Value> GetYiedOpInputs(pir::Block* block);
-
 void GetInputIds(pir::Operation* op,
                  const ValueExecutionInfo& value_exec_info,
                  std::unordered_map<pir::Value, std::vector<int>>* input_ids);
@@ -67,9 +65,5 @@ void InsertInplacedExternalInputsToOuts(

 bool GetCondData(const phi::DenseTensor& cond);

-void CopyBranchOutput(const std::vector<std::string>& var_names,
-                      const std::vector<Variable*>& output_vars,
-                      Scope* inner_scope);
-
 }  // namespace framework
 }  // namespace paddle