[Dy2St] Use ShadowOutputOp to get dy2st output (PaddlePaddle#60363)
SigureMo authored Dec 28, 2023
1 parent 180ded5 commit 875fbfb
Showing 5 changed files with 48 additions and 40 deletions.
@@ -542,6 +542,10 @@ void HandleForSpecialOp(pir::Operation* op,
// change opreand name to param_name
auto orig_name = value_exe_info->GetValue2VarName().at(value);

if (var_name == orig_name) {
return;
}

if (value_exe_info->GetScope()->FindVar(var_name) != nullptr) {
const_cast<Scope*>(value_exe_info->GetScope())->EraseVars({var_name});
VLOG(1) << "var " << var_name << " has been removed from scope";
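The four added lines above are an early-return guard: if the requested var_name already matches the value's current name, HandleForSpecialOp skips the rename and, in particular, does not erase the existing scope entry. A minimal Python sketch of that guard, using hypothetical stand-ins (a plain dict for the scope, value2name for GetValue2VarName()) rather than Paddle's real API:

    def handle_rename(scope: dict, value2name: dict, value, var_name: str) -> None:
        # Illustrative stand-in for the portion of HandleForSpecialOp shown above.
        orig_name = value2name[value]
        if var_name == orig_name:
            # Newly added guard: names already match, so there is nothing to
            # rename and the existing scope entry must not be dropped.
            return
        if var_name in scope:
            # Mirrors the EraseVars({var_name}) branch: remove the stale entry.
            del scope[var_name]
            print(f"var {var_name} has been removed from scope")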
44 changes: 22 additions & 22 deletions paddle/fluid/pybind/pir.cc
@@ -1057,14 +1057,14 @@ std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
std::make_pair(associated_array_key, associated_array_value));
}

void AppendSetParameter(Program *forward_program,
void AppendShadowOutput(Program *forward_program,
const pir::OpResult &result,
const std::string &name,
size_t start_point) {
pir::IrContext *ctx = pir::IrContext::Instance();
auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name());
auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name());
pir::AttributeMap attribute_map = {
{"parameter_name", pir::StrAttribute::get(ctx, name)},
{"output_name", pir::StrAttribute::get(ctx, name)},
};
pir::Operation *operation =
pir::Operation::Create({result}, attribute_map, {}, op_info);
@@ -1077,7 +1077,7 @@ void AppendSetParameter(Program *forward_program,
}
}

int AppendSetParameters(Program *forward_program,
int AppendShadowOutputs(Program *forward_program,
const std::vector<pir::OpResult> &outputs_op_result,
int start_point,
std::string name_prefix) {
@@ -1086,9 +1086,9 @@ int AppendSetParameters(Program *forward_program,

for (const auto &result : outputs_op_result) {
if (!added_op_result.count(result) || IsFakeOpResult(result)) {
std::string parameter_name = name_prefix + std::to_string(counter);
AppendSetParameter(
forward_program, result, parameter_name, start_point + counter);
std::string shadow_output_name = name_prefix + std::to_string(counter);
AppendShadowOutput(
forward_program, result, shadow_output_name, start_point + counter);
counter += 1;
added_op_result.insert(result);
}
@@ -1204,20 +1204,20 @@ SplitedResult SplitForwardBackward(
if (v.impl() == nullptr) {
return;
}
// NOTE(Aurelius84): we should skip insert SetParameterOp repeatly by
// NOTE(Aurelius84): we should skip insert ShadowOutputOp repeatly by
// calling SplitForwardBackward multi-times.
std::string parameter_name =
std::string shadow_output_name =
std::string("output_") + std::to_string(counter);
std::unordered_set<pir::Value> inserted_value;
for (auto it = forward_program->block()->rbegin();
it != forward_program->block()->rend();
++it) {
if (it->isa<pir::SetParameterOp>()) {
if (it->isa<pir::ShadowOutputOp>()) {
auto out_name =
it->attribute<pir::StrAttribute>("parameter_name").AsString();
if (out_name == parameter_name) {
it->attribute<pir::StrAttribute>("output_name").AsString();
if (out_name == shadow_output_name) {
VLOG(4) << out_name
<< " has been inserted SetParameterOp, skip it now.";
<< " has been inserted ShadowOutputOp, skip it now.";
return;
}

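This block exists because SplitForwardBackward can be called more than once: before inserting a new ShadowOutputOp for an output, the code scans the block from the end and bails out if a builtin.shadow_output op with the same output_name is already present. Roughly, as a hypothetical Python check:

    def already_has_shadow_output(block_ops, shadow_output_name: str) -> bool:
        # block_ops: hypothetical list of (op_name, attrs) pairs, scanned in
        # reverse like the rbegin()/rend() loop in the C++ above.
        for op_name, attrs in reversed(block_ops):
            if op_name == "builtin.shadow_output":
                if attrs.get("output_name") == shadow_output_name:
                    return True  # already inserted by an earlier call; skip it
        return False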
@@ -1228,9 +1228,9 @@ SplitedResult SplitForwardBackward(
if (inserted_value.count(forward_value_map[v])) {
return;
}
auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name());
auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name());
pir::AttributeMap attribute_map = {
{"parameter_name", pir::StrAttribute::get(ctx, parameter_name)},
{"output_name", pir::StrAttribute::get(ctx, shadow_output_name)},
};
pir::Operation *operation = pir::Operation::Create(
{forward_value_map[v]}, attribute_map, {}, op_info);
@@ -1245,9 +1245,9 @@ SplitedResult SplitForwardBackward(
if (v.impl() == nullptr) {
return;
}
auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name());
auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name());
pir::AttributeMap attribute_map = {
{"parameter_name",
{"output_name",
pir::StrAttribute::get(
ctx, std::string("output_") + std::to_string(counter))},
};
@@ -1372,10 +1372,10 @@ pir::Type CreateSelectedRowsTypeByDenseTensor(pir::Type dense_tensor_type) {
}
}

void ResetParameterName(pir::Operation *op, const std::string &name) {
void ResetShadowOutputName(pir::Operation *op, const std::string &name) {
pir::IrContext *ctx = pir::IrContext::Instance();
if (op->isa<pir::SetParameterOp>()) {
op->set_attribute("parameter_name", pir::StrAttribute::get(ctx, name));
if (op->isa<pir::ShadowOutputOp>()) {
op->set_attribute("output_name", pir::StrAttribute::get(ctx, name));
}
}

@@ -1410,9 +1410,9 @@ std::map<int, int> GetOpInplaceInfo(const pir::Operation *op) {
void BindUtils(pybind11::module *m) {
m->def("clone_program", CloneProgram);
m->def("get_op_inplace_info", GetOpInplaceInfo);
m->def("reset_parameter_name", ResetParameterName);
m->def("reset_shadow_output_name", ResetShadowOutputName);
m->def("split_program", SplitForwardBackward);
m->def("append_set_parameters", AppendSetParameters);
m->def("append_shadow_outputs", AppendShadowOutputs);
m->def("fake_op_result", FakeOpResult);
m->def("is_fake_op_result", IsFakeOpResult);
m->def("get_current_insertion_point", []() -> PyInsertionPoint {
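On the Python side the rebinding is purely a rename: reset_parameter_name becomes reset_shadow_output_name and append_set_parameters becomes append_shadow_outputs, matching the callers updated later in this diff. A small mock sketch (plain Python objects, not the real pybind-wrapped pir types) of what reset_shadow_output_name now does:

    class MockOp:
        # Hypothetical stand-in for a pir operation exposed to Python.
        def __init__(self, name: str, attrs: dict):
            self._name, self.attrs = name, dict(attrs)

        def isa(self, op_name: str) -> bool:
            return self._name == op_name

        def set_attribute(self, key: str, value: str) -> None:
            self.attrs[key] = value

    def reset_shadow_output_name(op: MockOp, name: str) -> None:
        # Mirrors ResetShadowOutputName: only builtin.shadow_output ops are
        # renamed, and the attribute key is now "output_name".
        if op.isa("builtin.shadow_output"):
            op.set_attribute("output_name", name)

    op = MockOp("builtin.shadow_output", {"output_name": "output_0"})
    reset_shadow_output_name(op, "renamed_output_0")  # hypothetical new name
    assert op.attrs["output_name"] == "renamed_output_0"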
35 changes: 20 additions & 15 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -103,7 +103,7 @@ def union(self, x, y):
self.father[father_x] = father_y

def find_root(self, x):
if not self.father.__contains__(x):
if x not in self.father:
self.father[x] = x
if self.father[x].is_same(x):
return x
@@ -135,24 +135,29 @@ def _get_value_name_map_from_program(cls, program):
ret = ValueDict()
ret[fake_op_result()] = "FakeVar"
for op in program.global_block().ops:
if op.name() == "pd_op.data":
ret[op.result(0)] = op.attrs()["name"]
if op.name() == "builtin.set_parameter":
ret[op.operand(0).source()] = op.attrs()["parameter_name"]
if op.name() == "builtin.parameter":
elif op.name() == "builtin.parameter":
ret[op.result(0)] = op.attrs()["parameter_name"]
elif op.name() == "builtin.shadow_output":
ret[op.operand(0).source()] = op.attrs()["output_name"]
elif op.name() == "pd_op.data":
ret[op.result(0)] = op.attrs()["name"]
return ret
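The reordered dispatch also has to recognize builtin.shadow_output now: for output-style ops (builtin.set_parameter, builtin.shadow_output) the named value is the op's first operand source, whereas defining ops (builtin.parameter, pd_op.data) name their result(0). A toy illustration with hypothetical dict-based mock ops rather than real pir values:

    value_a, value_b = object(), object()
    mock_ops = [
        {"name": "pd_op.data", "result0": value_a, "attrs": {"name": "x"}},
        {"name": "builtin.shadow_output", "operand0": value_b,
         "attrs": {"output_name": "output_0"}},
    ]

    ret = {}
    for op in mock_ops:
        if op["name"] == "builtin.shadow_output":
            # output ops consume the value: map the operand's source value
            ret[op["operand0"]] = op["attrs"]["output_name"]
        elif op["name"] == "pd_op.data":
            # defining ops produce the value: map result(0)
            ret[op["result0"]] = op["attrs"]["name"]

    assert ret[value_a] == "x" and ret[value_b] == "output_0"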

@classmethod
def _get_name_defining_op(cls, program, value):
for op in program.global_block().ops:
if op.name() == "pd_op.data":
if op.name() == "builtin.set_parameter":
if value.is_same(op.operand(0).source()):
return op
elif op.name() == "builtin.parameter":
if value.is_same(op.result(0)):
return op
if op.name() == "builtin.set_parameter":
elif op.name() == "builtin.shadow_output":
if value.is_same(op.operand(0).source()):
return op
if op.name() == "builtin.parameter":
elif op.name() == "pd_op.data":
if value.is_same(op.result(0)):
return op
return None
@@ -291,7 +296,7 @@ def _forward_backward_program(self):
def program_attr(self):
assert (
self.finish_pass is False
), "program_attr() is called by PartialProgramLayer, don't call it matually, use program_name_attr instead."
), "program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead."
# can't apply pass after call this function.
self.finish_pass = True
fwd_map = {
@@ -346,7 +351,7 @@ def has_name(value):
if has_name(ufset.find_root(value)):
name_defining_op = self._get_name_defining_op(program, value)
if name_defining_op:
paddle.core.pir.reset_parameter_name(
paddle.core.pir.reset_shadow_output_name(
name_defining_op, value2name[ufset.find_root(value)]
)

@@ -384,8 +389,8 @@ class PirPassContext:
"""

INPUT_OP_NAME = "pd_op.data"
PARM_OP_NAME = "builtin.parameter"
OUTPUT_OP_NAME = "builtin.set_parameter"
PARAM_OP_NAME = "builtin.parameter"
OUTPUT_OP_NAME = "builtin.shadow_output"

@classmethod
def apply(cls, runable_program, build_strategy):
@@ -419,7 +424,7 @@ def _prepare_attr(cls, program):
op_name = op.name()
if op_name == cls.INPUT_OP_NAME:
inputs.append(op.result(0))
elif op_name == cls.PARM_OP_NAME:
elif op_name == cls.PARAM_OP_NAME:
params.append(op.result(0))
elif op_name == cls.OUTPUT_OP_NAME:
outputs.append(op.operand(0).source())
@@ -546,7 +551,7 @@ def origin_runable_program(self):
inputs = list(self._inputs.var_list)
outputs = list(self._outputs.var_list)
params = self._param_values
paddle.base.libpaddle.pir.append_set_parameters(
paddle.base.libpaddle.pir.append_shadow_outputs(
self._origin_main_program,
outputs,
len(self._origin_main_program.global_block().ops),
@@ -796,7 +801,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
dtype=out_op_result.dtype,
)
forward_outputs_grads.append(value)
paddle.base.libpaddle.pir.append_set_parameters(
paddle.base.libpaddle.pir.append_shadow_outputs(
program,
forward_outputs_grads,
len(program.global_block().ops),
@@ -861,7 +866,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
)
)
backward_end_op_index = len(program.global_block().ops)
paddle.base.libpaddle.pir.append_set_parameters(
paddle.base.libpaddle.pir.append_shadow_outputs(
program,
output_grads_to_append,
backward_end_op_index,
2 changes: 1 addition & 1 deletion python/paddle/jit/pir_dy2static/parameter_recorder.py
@@ -81,7 +81,7 @@ def get(self, program, value):
return None
root_var = inplace_dict[value]
saved = []
while inplace_dict.__contains__(root_var):
while root_var in inplace_dict:
saved.append(root_var)
root_var = inplace_dict[root_var]
for var in saved:
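The change in parameter_recorder.py is purely stylistic: inplace_dict.__contains__(root_var) becomes the idiomatic root_var in inplace_dict (the same cleanup as in find_root above); both spellings invoke the same __contains__ protocol. The surrounding loop walks an alias chain down to its root, roughly like this self-contained sketch with hypothetical string keys in place of pir values:

    inplace_dict = {"a": "b", "b": "c"}  # hypothetical alias chain a -> b -> c
    value = "a"

    root_var = inplace_dict[value]
    saved = []
    while root_var in inplace_dict:      # same as inplace_dict.__contains__(root_var)
        saved.append(root_var)
        root_var = inplace_dict[root_var]

    assert root_var == "c" and saved == ["b"]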
3 changes: 1 addition & 2 deletions test/dygraph_to_static/test_tensor_memcpy_on_cpu.py
@@ -18,7 +18,6 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pt,
test_legacy_and_pt_and_pir,
)

@@ -69,7 +68,7 @@ def _run(self):
x2 = paddle.jit.to_static(tensor_copy_to_cuda)(x1)
return x1.place, x2.place, x2.numpy()

@test_legacy_and_pt
@test_legacy_and_pt_and_pir
def test_tensor_cuda_on_default_cpu(self):
if not paddle.is_compiled_with_cuda():
return