From f7728123b10d8496e86b186ddea54d5bfe18a1be Mon Sep 17 00:00:00 2001
From: huangjiyi <947613776@qq.com>
Date: Wed, 24 Jan 2024 06:00:31 +0000
Subject: [PATCH] Replace remaining OpResult usages with Value in PIR code and
 tests

---
 .../fluid/pir/dialect/op_generator/api_gen.py |  8 ++++----
 paddle/fluid/pir/drr/ir_operation_factory.cc  | 15 ++++-----------
 .../pir/transforms/constant_folding_pass.cc   |  3 +--
 .../replace_fetch_with_shadow_output_pass.cc  |  2 +-
 .../transforms/transform_general_functions.cc |  4 ++--
 .../fluid/pybind/manual_static_op_function.h  |  2 +-
 paddle/fluid/pybind/op_function_common.cc     | 10 +++-------
 paddle/pir/core/builtin_op.cc                 | 18 +++++++++---------
 paddle/pir/core/ir_context.h                  |  1 -
 paddle/pir/core/op_info.h                     |  1 -
 test/cpp/pir/core/block_operand_test.cc       |  6 ++----
 test/cpp/pir/tools/test_op.cc                 |  2 +-
 test/cpp/pir/tools/test_op.h                  |  2 +-
 test/ir/pir/test_symbol_overload.py           |  4 ++--
 test/legacy_test/test_expand_v2_op.py         |  4 ++--
 test/legacy_test/test_reshape_op.py           |  4 ++--
 test/legacy_test/test_where_op.py             |  4 ++--
 17 files changed, 37 insertions(+), 53 deletions(-)

diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py
index 57450de0f5cf90..f3f57f78b8d04a 100644
--- a/paddle/fluid/pir/dialect/op_generator/api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py
@@ -137,13 +137,13 @@
     optional_{name} = {name};
 }}"""
 
-OPTIONAL_OPRESULT_OUTPUT_TEMPLATE = """
+OPTIONAL_VALUE_OUTPUT_TEMPLATE = """
     paddle::optional<pir::Value> optional_{name};
     if (!IsEmptyValue({op_name}_op.result({index}))) {{
         optional_{name} = paddle::make_optional<pir::Value>({op_name}_op.result({index}));
     }}"""
 
-OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE = """
+OPTIONAL_VECTOR_VALUE_OUTPUT_TEMPLATE = """
     paddle::optional<std::vector<pir::Value>> optional_{name};
     if (!IsEmptyValue({op_name}_op.result({index}))) {{
         auto optional_{name}_slice_op = ApiBuilder::Instance().GetBuilder()->Build<pir::SplitOp>({op_name}_op.result({index}));
@@ -423,13 +423,13 @@ def _gen_handle_optional_outputs(self, op_info, op_name):
                 continue
             if self._is_optional_output(op_info, name):
                 if VECTOR_TYPE in type:
-                    ret += OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE.format(
+                    ret += OPTIONAL_VECTOR_VALUE_OUTPUT_TEMPLATE.format(
                         name=name,
                         op_name=op_name,
                         index=i,
                     )
                 else:
-                    ret += OPTIONAL_OPRESULT_OUTPUT_TEMPLATE.format(
+                    ret += OPTIONAL_VALUE_OUTPUT_TEMPLATE.format(
                         name=name,
                         op_name=op_name,
                         index=i,
diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc
index 1d6fa80f31605e..85dad3ca297318 100644
--- a/paddle/fluid/pir/drr/ir_operation_factory.cc
+++ b/paddle/fluid/pir/drr/ir_operation_factory.cc
@@ -35,10 +35,7 @@ void OperationFactory::RegisterManualOpCreator() {
          const pir::AttributeMap& attrs,
         pir::PatternRewriter& rewriter) {
        return rewriter.Build<paddle::dialect::FusedGemmEpilogueOp>(
-            inputs[0].dyn_cast<pir::OpResult>(),
-            inputs[1].dyn_cast<pir::OpResult>(),
-            inputs[2].dyn_cast<pir::OpResult>(),
-            attrs);
+            inputs[0], inputs[1], inputs[2], attrs);
       });
   RegisterOperationCreator(
       "pd_op.fused_gemm_epilogue_grad",
@@ -46,11 +43,7 @@ void OperationFactory::RegisterManualOpCreator() {
          const pir::AttributeMap& attrs,
         pir::PatternRewriter& rewriter) {
        return rewriter.Build<paddle::dialect::FusedGemmEpilogueGradOp>(
-            inputs[0].dyn_cast<pir::OpResult>(),
-            inputs[1].dyn_cast<pir::OpResult>(),
-            inputs[2].dyn_cast<pir::OpResult>(),
-            inputs[3].dyn_cast<pir::OpResult>(),
-            attrs);
+            inputs[0], inputs[1], inputs[2], inputs[3], attrs);
       });
   RegisterOperationCreator("builtin.combine",
                            [](const std::vector<pir::Value>& inputs,
@@ -64,8 +57,8 @@ void OperationFactory::RegisterManualOpCreator() {
          const pir::AttributeMap& attrs,
         pir::PatternRewriter& rewriter) {
        return rewriter.Build<paddle::dialect::ScaleOp>(
-            inputs[0].dyn_cast<pir::OpResult>(),
-            inputs[1].dyn_cast<pir::OpResult>(),
+            inputs[0],
+            inputs[1],
             attrs.at("bias").dyn_cast<pir::FloatAttribute>().data(),
             attrs.at("bias_after_scale").dyn_cast<pir::BoolAttribute>().data());
       });
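
Note on the hunks above: the DRR creators can drop their per-operand downcast
because the Build() overloads involved accept pir::Value operands directly. A
minimal before/after sketch of the pattern, using a hypothetical MyOp (not
part of this patch) whose Build() takes a pir::Value operand:

    // Before: each operand was narrowed to the concrete result class first.
    rewriter.Build<MyOp>(inputs[0].dyn_cast<pir::OpResult>(), attrs);
    // After: pir::OpResult converts implicitly to its base handle
    // pir::Value, so the stored handle can be passed straight through.
    rewriter.Build<MyOp>(inputs[0], attrs);
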
diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index 3bfda8e00d636f..009ec08e7523e3 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -346,8 +346,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
                                prev_op->name()));
         }
       } else {
-        op_inputs.push_back(
-            op->operand_source(i).dyn_cast<pir::OpResult>() /*nullptr*/);
+        op_inputs.push_back(nullptr);
       }
     }
 
diff --git a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc
index 8029cfc9ddbf5e..15fa2bf089deec 100644
--- a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc
+++ b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc
@@ -29,7 +29,7 @@ class ReplaceFetchWithShadowOutputPattern
       paddle::dialect::FetchOp op,
       pir::PatternRewriter& rewriter) const override {  // NOLINT
     rewriter.Build<pir::ShadowOutputOp>(
-        op->operand_source(0).dyn_cast<pir::OpResult>(),
+        op->operand_source(0),
         op->attributes().at("name").dyn_cast<pir::StrAttribute>().AsString());
     rewriter.EraseOp(op);
     return true;
diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc
index 9b12aea67955a3..38d08b1b4bbb6e 100644
--- a/paddle/fluid/pir/transforms/transform_general_functions.cc
+++ b/paddle/fluid/pir/transforms/transform_general_functions.cc
@@ -61,7 +61,7 @@ void GetUsedExternalValueImpl(
 namespace pir {
 
 std::string GetParameterNameFromValue(pir::Value value) {
-  pir::Operation* owner = value.dyn_cast<pir::OpResult>().owner();
+  pir::Operation* owner = value.defining_op();
   std::string name;
   if (owner->isa<pir::ParameterOp>()) {
     pir::ParameterOp op = owner->dyn_cast<pir::ParameterOp>();
@@ -104,7 +104,7 @@ Operation* GetDefiningOpForInput(const Operation* op, uint32_t index) {
       index < op->num_operands() && op->operand_source(index),
       true,
       phi::errors::InvalidArgument("Intput operand's index must be valid."));
-  return op->operand_source(index).dyn_cast<pir::OpResult>().owner();
+  return op->operand_source(index).defining_op();
 }
 
 std::vector<std::pair<Operation*, int32_t>> GetUseOpsForOutput(
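
Note on the hunks above: pir::Value::defining_op() is the one-call
replacement for the value.dyn_cast<pir::OpResult>().owner() chain. A short
equivalence sketch (variable names are illustrative only):

    pir::Value v = op->operand_source(0);
    // Old spelling: only valid when v actually is an OpResult; a block
    // argument yields a null OpResult, and calling owner() on it is unsafe.
    pir::Operation* def = v.dyn_cast<pir::OpResult>().owner();
    // New spelling: returns the producing operation, or nullptr when v is
    // not an operation result (e.g. a block argument).
    def = v.defining_op();
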
diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h
index a15147156eff2b..9fdbdd547908b0 100644
--- a/paddle/fluid/pybind/manual_static_op_function.h
+++ b/paddle/fluid/pybind/manual_static_op_function.h
@@ -52,7 +52,7 @@ static PyObject *static_api_set_parameter(PyObject *self,
     VLOG(6) << "Add set_parameter op into program";
     VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
 
-    // Get OpResult from args
+    // Get Value from args
     PyObject *parameter_obj = PyTuple_GET_ITEM(args, 0);
     auto parameter = CastPyArg2Value(parameter_obj, "parameter", 0);
 
diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc
index b6dd77553d4d4f..41e0f3e5a38fbd 100644
--- a/paddle/fluid/pybind/op_function_common.cc
+++ b/paddle/fluid/pybind/op_function_common.cc
@@ -865,18 +865,14 @@ void CastPyArg2AttrValues(PyObject* obj,
     Py_ssize_t len = PyList_Size(obj);
     PyObject* item = nullptr;
     for (Py_ssize_t i = 0; i < len; i++) {
-      // TODO(xiongkun): judge OpResult or Value;
+      // TODO(xiongkun): judge Value;
       item = PyList_GetItem(obj, i);
       ::pybind11::detail::instance* inst =
           (::pybind11::detail::instance*)item;  // NOLINT
       void** vh = inst->simple_layout ? inst->simple_value_holder
                                       : &inst->nonsimple.values_and_holders[0];
-      ::pir::OpResult* opresult = reinterpret_cast<::pir::OpResult*>(vh[0]);
-      if (opresult->impl() == nullptr) {
-        results.emplace_back(pir::Value(nullptr));
-      } else {
-        results.emplace_back(pir::Value(opresult->Value::impl()));
-      }
+      ::pir::Value* value = reinterpret_cast<::pir::Value*>(vh[0]);
+      results.emplace_back(pir::Value(value->impl()));
     }
   } else {
     PADDLE_THROW(platform::errors::InvalidType(
diff --git a/paddle/pir/core/builtin_op.cc b/paddle/pir/core/builtin_op.cc
index 193d789b53b658..dc4a6c9de906a3 100644
--- a/paddle/pir/core/builtin_op.cc
+++ b/paddle/pir/core/builtin_op.cc
@@ -253,8 +253,8 @@ void SliceOp::Build(Builder &builder,
 void SliceOp::PassStopGradients(OperationArgument &argument, int index) {
   std::vector<pir::Attribute> outs_stop_gradient(
       1, pir::BoolAttribute::get(pir::IrContext::Instance(), true));
-  if (auto input = argument.inputs[0].dyn_cast<pir::OpResult>()) {
-    auto *defining_op = input.owner();
+  if (auto input = argument.inputs[0]) {
+    auto *defining_op = input.defining_op();
     if (defining_op && defining_op->isa<CombineOp>()) {
       IR_ENFORCE(defining_op->HasAttribute(kStopGradientAttrName),
                  "Required CombineOp must have attribute %s",
@@ -274,8 +274,8 @@ void SliceOp::RefreshStopGradients() {
   std::vector<pir::Attribute> outs_stop_gradient(
       1, pir::BoolAttribute::get(pir::IrContext::Instance(), true));
   auto index = attribute("index").dyn_cast<pir::Int32Attribute>().data();
-  if (auto input = (*this)->operand_source(0).dyn_cast<pir::OpResult>()) {
-    auto *defining_op = input.owner();
+  if (auto input = (*this)->operand_source(0)) {
+    auto *defining_op = input.defining_op();
     if (defining_op && defining_op->isa<CombineOp>()) {
       IR_ENFORCE(defining_op->HasAttribute(kStopGradientAttrName),
                  "Required CombineOp must have attribute %s",
@@ -350,8 +350,8 @@ void SplitOp::Build(Builder &builder,
 
 void SplitOp::PassStopGradients(OperationArgument &argument) {
   std::vector<bool> defaut_stop_gradients(argument.output_types.size(), true);
-  if (auto input = argument.inputs[0].dyn_cast<pir::OpResult>()) {
-    auto *defining_op = input.owner();
+  if (auto input = argument.inputs[0]) {
+    auto *defining_op = input.defining_op();
     if (defining_op && defining_op->isa<CombineOp>()) {
       IR_ENFORCE(argument.output_types.size(),
                  defining_op->num_operands(),
@@ -391,8 +391,8 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
 
 void SplitOp::RefreshStopGradients() {
   std::vector<bool> default_stop_gradients((*this)->num_results(), true);
-  if (auto input = (*this)->operand_source(0).dyn_cast<pir::OpResult>()) {
-    auto *defining_op = input.owner();
+  if (auto input = (*this)->operand_source(0)) {
+    auto *defining_op = input.defining_op();
     if (defining_op && defining_op->isa<CombineOp>()) {
       IR_ENFORCE((*this)->num_results(),
                  defining_op->num_operands(),
@@ -403,7 +403,7 @@ void SplitOp::RefreshStopGradients() {
     for (uint32_t i = 0; i < defining_op->num_operands(); ++i) {
       auto value = defining_op->operand_source(i);
       if (!value) continue;
-      auto *operand_defining_op = value.dyn_cast<pir::OpResult>().owner();
+      auto *operand_defining_op = value.defining_op();
       if (operand_defining_op->HasAttribute(kStopGradientAttrName)) {
         auto attrs = operand_defining_op->attribute(kStopGradientAttrName)
                          .dyn_cast<pir::ArrayAttribute>()
diff --git a/paddle/pir/core/ir_context.h b/paddle/pir/core/ir_context.h
index f2686573cc67d9..b6f84ae8e3574b 100644
--- a/paddle/pir/core/ir_context.h
+++ b/paddle/pir/core/ir_context.h
@@ -30,7 +30,6 @@ class TypeId;
 class Dialect;
 class OpInfo;
 class Type;
-class OpResult;
 class Attribute;
 class Operation;
 class InterfaceValue;
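
Note on the builtin_op.cc hunks above: the `if (auto input = ...)` form works
because pir::Value has a boolean conversion that tests whether the underlying
impl is non-null, so the OpResult downcast is no longer needed as a null
filter. A sketch of the idiom (names are illustrative):

    pir::Value input = argument.inputs[0];
    if (input) {  // true iff input.impl() != nullptr
      pir::Operation* defining_op = input.defining_op();
      if (defining_op && defining_op->isa<pir::CombineOp>()) {
        // ... inspect or refresh stop_gradient attributes here ...
      }
    }
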
diff --git a/paddle/pir/core/op_info.h b/paddle/pir/core/op_info.h
index 6ca26114011c4d..f4faf76d9c822f 100644
--- a/paddle/pir/core/op_info.h
+++ b/paddle/pir/core/op_info.h
@@ -21,7 +21,6 @@ namespace pir {
 
 class OpInfoImpl;
 class IrContext;
-class OpResult;
 class Type;
 class Attribute;
 class Dialect;
diff --git a/test/cpp/pir/core/block_operand_test.cc b/test/cpp/pir/core/block_operand_test.cc
index 4538255143f724..238e9e6cdd2e9e 100644
--- a/test/cpp/pir/core/block_operand_test.cc
+++ b/test/cpp/pir/core/block_operand_test.cc
@@ -40,8 +40,7 @@ TEST(block_operand_test, type_block) {
   region.push_back(block_3);
 
   builder.SetInsertionPointToBlockEnd(block_1);
-  auto op1 =
-      builder.Build<test::BranchOp>(std::vector<pir::OpResult>{}, block_2);
+  auto op1 = builder.Build<test::BranchOp>(std::vector<pir::Value>{}, block_2);
 
   EXPECT_TRUE(block_2->HasOneUse());
   EXPECT_FALSE(block_2->use_empty());
@@ -55,8 +54,7 @@ TEST(block_operand_test, type_block) {
   EXPECT_EQ(iter_curr->owner(), op1);
 
   builder.SetInsertionPointToBlockEnd(block_3);
-  auto op3 =
-      builder.Build<test::BranchOp>(std::vector<pir::OpResult>{}, block_1);
+  auto op3 = builder.Build<test::BranchOp>(std::vector<pir::Value>{}, block_1);
   block_operand = op3->block_operand(0);
   block_operand.set_source(block_2);
   EXPECT_EQ(block_2, block_operand.source());
diff --git a/test/cpp/pir/tools/test_op.cc b/test/cpp/pir/tools/test_op.cc
index cb2bf74293103d..79e415f26c23e2 100644
--- a/test/cpp/pir/tools/test_op.cc
+++ b/test/cpp/pir/tools/test_op.cc
@@ -23,7 +23,7 @@ void RegionOp::Build(pir::Builder &builder, pir::OperationArgument &argument) {
 
 void BranchOp::Build(pir::Builder &builder,             // NOLINT
                      pir::OperationArgument &argument,  // NOLINT
-                     const std::vector<pir::OpResult> &target_operands,
+                     const std::vector<pir::Value> &target_operands,
                      pir::Block *target) {
   argument.AddInputs(target_operands.begin(), target_operands.end());
   argument.AddSuccessor(target);
diff --git a/test/cpp/pir/tools/test_op.h b/test/cpp/pir/tools/test_op.h
index 253584595609be..156d92d16019be 100644
--- a/test/cpp/pir/tools/test_op.h
+++ b/test/cpp/pir/tools/test_op.h
@@ -49,7 +49,7 @@ class BranchOp : public pir::Op<BranchOp> {
   static constexpr const char **attributes_name = nullptr;
   static void Build(pir::Builder &builder,             // NOLINT
                     pir::OperationArgument &argument,  // NOLINT
-                    const std::vector<pir::OpResult> &target_operands,
+                    const std::vector<pir::Value> &target_operands,
                     pir::Block *target);
   void VerifySig() const;
 };
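
Note on the test_op hunks above: widening BranchOp::Build to
std::vector<pir::Value> lets call sites pass any SSA value (operation results
as well as block arguments), and it shortens the block_operand_test.cc calls
to a single line because the shorter type name now fits within the column
limit:

    // Build a branch op with an empty operand list targeting block_2.
    auto op1 = builder.Build<test::BranchOp>(std::vector<pir::Value>{}, block_2);
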
diff --git a/test/ir/pir/test_symbol_overload.py b/test/ir/pir/test_symbol_overload.py
index 7847ac55085554..64470ef7909f4f 100644
--- a/test/ir/pir/test_symbol_overload.py
+++ b/test/ir/pir/test_symbol_overload.py
@@ -73,7 +73,7 @@ def forward(self, x, y):
         return z1, z2, z3, z4
 
 
-class TestOpresultSymbol(unittest.TestCase):
+class TestValueSymbol(unittest.TestCase):
     def setUp(self):
         np.random.seed(2023)
         self.shape_x = [2, 1024, 1024]
@@ -133,7 +133,7 @@ def test_symbol_overload(self):
         self.assertEqual(ops_ref, ops)
 
 
-class TestOpresultCompareSymbol(unittest.TestCase):
+class TestValueCompareSymbol(unittest.TestCase):
     def setUp(self):
         np.random.seed(2023)
         self.shape_x = [2, 1024, 1024]
diff --git a/test/legacy_test/test_expand_v2_op.py b/test/legacy_test/test_expand_v2_op.py
index 78027701661099..71d8fbb955feb6 100644
--- a/test/legacy_test/test_expand_v2_op.py
+++ b/test/legacy_test/test_expand_v2_op.py
@@ -541,8 +541,8 @@ def test_check_output(self):
         self.check_output(check_prim=True)
 
 
-class TestExpandPirOpResultListShape(unittest.TestCase):
-    def test_opresult_list_shape(self):
+class TestExpandPirValueListShape(unittest.TestCase):
+    def test_value_list_shape(self):
         with paddle.pir_utils.IrGuard():
             x = paddle.static.data('x', [1, 3])
             shape = [2, paddle.full([], 4)]
diff --git a/test/legacy_test/test_reshape_op.py b/test/legacy_test/test_reshape_op.py
index 903b00a246da9c..a3d923d1db6ec9 100755
--- a/test/legacy_test/test_reshape_op.py
+++ b/test/legacy_test/test_reshape_op.py
@@ -734,8 +734,8 @@ def test_static(self):
         self.assertEqual(result[3].shape, (1,))
 
 
-class TestReshapePirOpResultListShape(unittest.TestCase):
-    def test_opresult_list_shape(self):
+class TestReshapePirValueListShape(unittest.TestCase):
+    def test_value_list_shape(self):
         with paddle.pir_utils.IrGuard():
             x = paddle.static.data(
                 'x',
diff --git a/test/legacy_test/test_where_op.py b/test/legacy_test/test_where_op.py
index 6f64ff15f45b97..fa94c8d874b786 100644
--- a/test/legacy_test/test_where_op.py
+++ b/test/legacy_test/test_where_op.py
@@ -798,11 +798,11 @@ def test_Variable():
 
         self.assertRaises(TypeError, test_Variable)
 
-        def test_OpResult():
+        def test_Value():
             with paddle.pir_utils.IrGuard():
                 paddle.where(cond_i, x_i, y_i)
 
-        self.assertRaises(TypeError, test_OpResult)
+        self.assertRaises(TypeError, test_Value)
 
         def test_type():
             x = paddle.static.data(name='x', shape=[-1, 4], dtype='bool')