Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] add python api for if #60895

Merged
merged 1 commit into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,36 @@ void IfInstruction::CopyBranchOutput(const std::vector<std::string>& var_names,
}

void IfInstruction::Run() {
DeviceContext().Wait();
if (cond_var_->Get<phi::DenseTensor>().data<bool>()[0]) {
bool cond = true;
if (cond_var_->IsType<phi::DenseTensor>()) {
auto& cond_tensor = cond_var_->Get<phi::DenseTensor>();
if (paddle::platform::is_cpu_place(cond_tensor.place())) {
cond = cond_tensor.data<bool>()[0];
} else {
// when platform::is_gpu_place(cond.place()) or
// platform::is_xpu_place(cond.place()) is true
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_CUSTOM_DEVICE)
DeviceContext().Wait();
phi::DenseTensor cpu_cond;
paddle::framework::TensorCopySync(
cond_tensor, platform::CPUPlace(), &cpu_cond);
cond = cpu_cond.data<bool>()[0];
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"This version of PaddlePaddle does NOT support GPU/XPU but got "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does NOT support GPU/XPU ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个目前是跟while的判定逻辑对齐的。

"GPU/XPU tensor Cond in WhileOp. Please compile WITH_GPU or "
"WITH_XPU option."));
#endif
}
} else if (cond_var_->IsType<VariableRefArray>()) {
auto& cond_array = cond_var_->Get<VariableRefArray>();
cond = std::all_of(
cond_array.begin(), cond_array.end(), [](const Variable* t) {
return t->Get<phi::DenseTensor>().numel() != 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

t 一定是 DenseTensor么,最好加一个检查?另外这里的判断逻辑就是要用tensor 的 numel 来判断么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a3475f9201270d0a14670a9a97523111
这个是复制的原始的condition_block.cc中的执行逻辑。我理解是只支持为DenseTensor的情况。如果不是DenseTensor, 在Get<phi::DenseTensor>()这句话里面会有错误提示。

});
}
if (cond) {
true_branch_inter_->Run({}, false);
CopyBranchOutput(true_branch_outputs_, true_branch_inter_);
} else {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/api_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ApiBuilder {
void SetParameter(const std::string& name,
std::unique_ptr<pir::Parameter>&& parameter);

std::shared_ptr<pir::Builder> GetBuilder() { return builder_; }
const std::shared_ptr<pir::Builder>& GetBuilder() const { return builder_; }

const pir::InsertionPoint& GetCurrentInsertionPoint() const {
return builder_->insertion_point();
Expand Down
32 changes: 21 additions & 11 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,14 @@ void IfOp::VerifyRegion() {
1u,
phi::errors::PreconditionNotMet("The size %d of true_region must be 1.",
(*this)->region(0).size()));
if ((*this)->region(0).front().size() > 0) {
auto &true_last_op = (*this)->region(0).front().back();
if ((*this)->num_results() != 0) {
auto &true_block = (*this)->region(0).front();
PADDLE_ENFORCE_GT(
true_block.size(),
0u,
phi::errors::PreconditionNotMet(
"The true block must have at least one op yield op."));
auto &true_last_op = true_block.back();
PADDLE_ENFORCE_EQ(true,
true_last_op.isa<pir::YieldOp>(),
phi::errors::PreconditionNotMet(
Expand All @@ -228,15 +234,19 @@ void IfOp::VerifyRegion() {
phi::errors::PreconditionNotMet(
"The size of last of true block op's input must be "
"equal to IfOp's outputs num."));
}
VLOG(4) << "Start Verifying false branch.";
PADDLE_ENFORCE_EQ(
(*this)->region(1).size(),
1u,
phi::errors::PreconditionNotMet("The size %d of false_region must be 1.",
(*this)->region(0).size()));
if ((*this)->region(1).front().size() > 0) {
auto &false_last_op = (*this)->region(1).front().back();
VLOG(4) << "Start Verifying false branch.";
PADDLE_ENFORCE_EQ((*this)->region(1).size(),
1u,
phi::errors::PreconditionNotMet(
"The size %d of false_region must be 1.",
(*this)->region(0).size()));
auto &false_block = (*this)->region(1).front();
PADDLE_ENFORCE_GT(
false_block.size(),
0u,
phi::errors::PreconditionNotMet(
"The false block must have at least one op yield op."));
auto &false_last_op = false_block.back();
PADDLE_ENFORCE_EQ(true,
false_last_op.isa<pir::YieldOp>(),
phi::errors::PreconditionNotMet(
Expand Down
21 changes: 0 additions & 21 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1124,27 +1124,6 @@ void HandleForIfOp(
"[%d]'s input of [%s] op MUST in map pair", 0, op_item->name()));
auto new_cond = map_value_pair->at(old_cond);

// NOTE(zhangbo): IfOp's input cond should be a cpu type.
AllocatedDenseTensorType new_cond_type =
new_cond.type().dyn_cast<AllocatedDenseTensorType>();
if (new_cond_type) {
if (new_cond_type.place().GetType() == phi::AllocationType::GPU) {
auto out_type = AllocatedDenseTensorType::get(
ctx, phi::CPUPlace(), old_cond.type().dyn_cast<DenseTensorType>());
phi::KernelKey kernel_key(
phi::Backend::GPU, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL);
new_cond = AddPlaceTransferOp(new_cond,
out_type,
new_cond_type.place(),
phi::CPUPlace(),
kernel_key,
block);
}
} else {
PADDLE_THROW(
phi::errors::Unimplemented("IfOp onlu support DenseTensorType"));
}

// Create IfOp and insert to kernel dialect program
pir::Builder builder(ctx, block);
auto old_ifop = op_item->dyn_cast<IfOp>();
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/pybind/control_flow_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using paddle::pybind::PyIfOp;
using paddle::pybind::PyWhileOp;
using pir::Block;
using pir::Builder;
using pir::CombineOp;
using pir::Operation;
using pir::Program;
using pir::Region;
Expand All @@ -60,13 +61,19 @@ void BindIfOp(py::module* m) {
return PyIfOp(ApiBuilder::Instance().GetBuilder()->Build<IfOp>(
cond, std::vector<Type>{}));
});
m->def("build_if_op", [](const std::vector<Value>& cond) {
auto& builder = ApiBuilder::Instance().GetBuilder();
auto new_cond = builder->Build<CombineOp>(cond).out();
return PyIfOp(builder->Build<IfOp>(new_cond, std::vector<Type>{}));
});
py::class_<PyIfOp> if_op(*m, "IfOp", R"DOC(
The PyIfOp is a encapsulation of IfOp. Compared with ifOp, it provides an additional 'update_output' interface.
The 'update_output' interface will construct a new IfOp operation to replace its underlying IfOp. In the process, the original
IfOp will be destroyed. In order to avoid the risk of memory used in python side, We encapsulate PyIfOp to python api.
)DOC");
if_op.def("true_block", &PyIfOp::true_block, return_value_policy::reference)
.def("false_block", &PyIfOp::false_block, return_value_policy::reference)
.def("cond", &PyIfOp::cond)
.def("update_output", &PyIfOp::UpdateOutput)
.def("as_operation", &PyIfOp::operation, return_value_policy::reference)
.def("results", [](PyIfOp& self) -> py::list {
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/base/layer_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Parameter,
dtype_is_floating,
in_dygraph_mode,
in_pir_mode,
)
from .layer_helper_base import LayerHelperBase
from .param_attr import ParamAttr
Expand Down Expand Up @@ -132,6 +133,8 @@ def append_bias_op(self, input_var, dim_start=1, dim_end=None):
b = self.create_parameter(
attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True
)
if in_pir_mode():
return input_var + b
tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
self.append_op(
type='elementwise_add',
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/base/layer_helper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import paddle

from . import core, unique_name
from .data_feeder import convert_dtype
from .framework import (
Variable,
_current_expected_place,
Expand Down Expand Up @@ -359,6 +360,8 @@ def create_parameter(
# set global dtype
if not dtype:
dtype = self.__dtype
if isinstance(dtype, core.DataType):
dtype = convert_dtype(dtype)
if is_bias:
suffix = 'b'
default_initializer = (
Expand Down
34 changes: 25 additions & 9 deletions python/paddle/static/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
default_main_program,
in_dygraph_mode,
in_dynamic_or_pir_mode,
in_pir_mode,
name_scope,
program_guard,
static_only,
Expand Down Expand Up @@ -191,10 +192,17 @@ def fc_base(
name=None,
):
helper = LayerHelper("fc", **locals())
check_type(input, 'input', (list, tuple, Variable), 'fc')
check_type(
input, 'input', (list, tuple, Variable, paddle.pir.Value), 'fc'
)
if isinstance(input, (list, tuple)):
for i, input_x in enumerate(input):
check_type(input_x, 'input[' + str(i) + ']', Variable, 'fc')
check_type(
input_x,
'input[' + str(i) + ']',
(Variable, paddle.pir.Value),
'fc',
)
dtype = helper.input_dtype()
check_dtype(
dtype, 'input', ['float16', 'uint16', 'float32', 'float64'], 'fc'
Expand All @@ -210,17 +218,25 @@ def fc_base(
w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False
)
tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="mul",
inputs={"X": input_var, "Y": w},
outputs={"Out": tmp},
attrs={"x_num_col_dims": num_flatten_dims, "y_num_col_dims": 1},
)
if in_pir_mode():
tmp = paddle.matmul(input_var, w)
else:
tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="mul",
inputs={"X": input_var, "Y": w},
outputs={"Out": tmp},
attrs={
"x_num_col_dims": num_flatten_dims,
"y_num_col_dims": 1,
},
)
mul_results.append(tmp)

if len(mul_results) == 1:
pre_bias = mul_results[0]
elif in_pir_mode():
pre_bias = paddle.add_n(mul_results)
else:
pre_bias = helper.create_variable_for_type_inference(dtype)
helper.append_op(
Expand Down
65 changes: 59 additions & 6 deletions python/paddle/static/nn/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,51 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return super().__exit__(exc_type, exc_val, exc_tb)


class If:
    '''
    **If**

    If is an operator that binds two blocks (true_block and false_block) to a specific condition.
    According to the condition, the corresponding block will be executed.

    Args:
        cond (Value|list): A value whose data type is bool controlling which block is executed.
            A ``list`` is also accepted; in that case the dtype and element-count
            checks are skipped and the list is passed to ``build_if_op`` as is.

    Raises:
        TypeError: If ``cond`` is not a list and the product of its shape
            dimensions is not 1.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> from paddle.static.nn.control_flow import If

            >>> label = paddle.rand([1])
            >>> limit = paddle.ones([1]) * 0.5
            >>> cond = paddle.less_than(x=label, y=limit)
            >>> if_op = If(cond)
            >>> with if_op.true_block():
            ...     pass
            >>> with if_op.false_block():
            ...     pass
    '''

    def __init__(self, cond):
        # A list condition bypasses the validation below and is handed to
        # build_if_op unchanged.
        if not isinstance(cond, list):
            check_variable_and_dtype(cond, 'cond', ['bool'], 'static.nn.If')
            # The product of all dimensions must be 1, i.e. the condition
            # holds exactly one element.
            if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
                raise TypeError(
                    "condition expected shape as [1], but given shape as {}.".format(
                        list(cond.shape)
                    )
                )
        self.if_op = build_if_op(cond)
        # Keep the condition as reported by the built op, not the raw argument.
        self.cond_var = self.if_op.cond()

    def true_block(self):
        '''Return the result of ``true_block()`` on the underlying if op.'''
        return self.if_op.true_block()

    def false_block(self):
        '''Return the result of ``false_block()`` on the underlying if op.'''
        return self.if_op.false_block()


class ConditionalBlock:
'''
**ConditionalBlock**
Expand Down Expand Up @@ -208,13 +253,23 @@ class ConditionalBlock:
'''

def __init__(self, inputs, is_scalar_condition=False, name=None):
for each_input in inputs:
check_type(each_input, "input", Variable, "ConditionalBlock")
self.inputs = inputs
if in_pir_mode():
if is_scalar_condition and len(inputs) != 1:
raise TypeError(
"For ConditionalBlock Api, Only support one input while is_scalar_condition is True"
)
return
else:
for each_input in inputs:
check_type(each_input, "input", Variable, "ConditionalBlock")

self.is_scalar_condition = is_scalar_condition
self.helper = LayerHelper('conditional_block', name=name)

def block(self):
if in_pir_mode():
return If(self.inputs).true_block()
return ConditionalBlockGuard(self)

def complete(self):
Expand Down Expand Up @@ -1244,9 +1299,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
return None
true_output = None
false_output = None
check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
check_type(name, "name", (str, type(None)), "base.layers.cond")
if in_pir_mode():
check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
check_type(name, "name", (str, type(None)), "base.layers.cond")
if_op = build_if_op(pred)
if true_fn is not None:
if not callable(true_fn):
Expand All @@ -1267,8 +1322,6 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
with if_op.false_block():
false_output = false_fn()
else:
check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
check_type(name, "name", (str, type(None)), "base.layers.cond")
helper = LayerHelper('cond', **locals())
copy_to_parent_func = lambda var: copy_var_to_parent_block(var, helper)
if true_fn is not None:
Expand Down
3 changes: 0 additions & 3 deletions test/dygraph_to_static/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
test_legacy_and_pt_and_pir,
)

import paddle
Expand Down Expand Up @@ -252,7 +251,6 @@ def setUp(self):

self.nested_for_loop_func = nested_for_loop_dyfunc

@test_legacy_and_pt_and_pir
def test_loop_vars(self):
for i in range(len(self.loop_funcs)):
func = self.loop_funcs[i]
Expand All @@ -268,7 +266,6 @@ def test_loop_vars(self):
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])

@test_legacy_and_pt_and_pir
def test_nested_loop_vars(self):
func = self.nested_for_loop_func
test_func = inspect.getsource(func)
Expand Down
7 changes: 4 additions & 3 deletions test/legacy_test/test_conditional_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def test_forward(self):
data = paddle.static.data(name='X', shape=[-1, 1], dtype='float32')
data.stop_gradient = False
cond = ConditionalBlock(inputs=[data])
out = paddle.tensor.create_tensor(dtype='float32')
out = paddle.tensor.fill_constant(
[10, 10], dtype='float32', value=0.0
)
out.stop_gradient = False
with cond.block():
hidden = paddle.static.nn.fc(x=data, size=10)
paddle.assign(hidden, out)
Expand All @@ -43,15 +46,13 @@ def test_forward(self):
x = np.random.random(size=(10, 1)).astype('float32')

outs = exe.run(main_program, feed={'X': x}, fetch_list=[out])[0]
print(outs)
loss = paddle.mean(out)
append_backward(loss=loss)
outs = exe.run(
main_program,
feed={'X': x},
fetch_list=[main_program.block(0).var(data.name + "@GRAD")],
)[0]
print(outs)


class TestConditionalBlockOpInferShape(unittest.TestCase):
Expand Down