Skip to content

Commit

Permalink
[PIR] Refine mutable attribute (PaddlePaddle#57306)
Browse files Browse the repository at this point in the history
* add code

* fix bug

* fix bug

* add code
  • Loading branch information
zhangbo9674 authored Sep 15, 2023
1 parent 358bfca commit f084ccd
Showing 1 changed file with 31 additions and 26 deletions.
57 changes: 31 additions & 26 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,7 @@
'ReduceIntArrayAxisInferMeta',
}

_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {
'SplitOp',
'SumOp',
'SplitWithNumOp',
'ConcatOp',
'MeanOp',
}

_PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE = {'FrobeniusNormOp'}

OP_BUILD_TEMPLATE = """
void {op_name}::Build({build_args}) {{
Expand Down Expand Up @@ -376,6 +369,25 @@ def GenBuildOutputs(
PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType"));
}}\n"""

CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector<int64_t> {name};
if ({name}_.owner()->info().id() == pir::TypeId::get<paddle::dialect::FullIntArrayOp>()) {{
{name} = {name}_.owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData();
}} else if ({name}_.type().isa<pir::VectorType>()) {{
size_t {name}_size = {name}_.type().dyn_cast<pir::VectorType>().size();
{name} = std::vector<int64_t>({name}_size, -1);
}} else if ({name}_.type().isa<paddle::dialect::DenseTensorType>()) {{
size_t {name}_size = phi::product({name}_.type().dyn_cast<paddle::dialect::DenseTensorType>().dims());
{name} = std::vector<int64_t>({name}_size, -1);
}} else {{
PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType"));
}}\n"""

CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name};
if ({name}_.owner()->info().id() == pir::TypeId::get<paddle::dialect::FullOp>()) {{
{name} = std::move(phi::Scalar({name}_.owner()
Expand Down Expand Up @@ -424,30 +436,23 @@ def GenBuildOutputs(
attr_dtype = op_mutable_attribute_type_list[idx]
# int_array
if attr_dtype[0] == "paddle::dialect::IntArrayAttribute":
if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
if (
op_class_name
in _PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE
):
build_output_str += CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
else:
build_output_str += (
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
# scalar
elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
)
else:
build_output_str += (
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
)
)
build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
)
# string
elif attr_dtype[0] == "pir::StrAttribute":
build_output_str += ""
Expand Down

0 comments on commit f084ccd

Please sign in to comment.