diff --git a/paddle/fluid/compiler/paddle2piano/piano_graph_executor.cc b/paddle/fluid/compiler/paddle2piano/piano_graph_executor.cc index 303580f2761f7e..175376ac35dbf2 100644 --- a/paddle/fluid/compiler/paddle2piano/piano_graph_executor.cc +++ b/paddle/fluid/compiler/paddle2piano/piano_graph_executor.cc @@ -37,7 +37,8 @@ std::unique_ptr PianoGraphExecutor::CreateInputOperand( auto* node = cluster_inputs_.at(id); PADDLE_ENFORCE_EQ(node->IsVar(), true, platform::errors::InvalidArgument( - "Cluster Sub-Graph Input should be var")); + "Cluster Sub-Graph Input should be var, but %s not", + node->Name().c_str())); const auto& var_name = node->Name(); @@ -71,10 +72,12 @@ GraphNodeVec PianoGraphExecutor::SortInternalCluster() const { for (auto* n : cluster_) { PADDLE_ENFORCE_EQ(n->IsOp(), true, platform::errors::PreconditionNotMet( - "Cluster's node all should be op node")); + "Cluster's node all should be op node, but %s not", + n->Name().c_str())); PADDLE_ENFORCE_EQ(PianoOpRegistry::IsPianoOp(n->Name()), true, platform::errors::PreconditionNotMet( - "Cluster's op all should be piano op")); + "Cluster's op all should be piano op, but %s not", + n->Name().c_str())); // the op's input is var for (auto* in_var : n->inputs) { // the var's input is op diff --git a/paddle/fluid/compiler/paddle2piano/piano_graph_executor_test.cc b/paddle/fluid/compiler/paddle2piano/piano_graph_executor_test.cc index 7008f682e00711..95697976e9153c 100644 --- a/paddle/fluid/compiler/paddle2piano/piano_graph_executor_test.cc +++ b/paddle/fluid/compiler/paddle2piano/piano_graph_executor_test.cc @@ -54,12 +54,16 @@ class TestOp : public OperatorWithKernel { void InferShape(InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "test"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "test"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "test"); + OP_INOUT_CHECK(ctx->HasOutput("Res"), "Output", "Res", "test"); auto in_dims = ctx->GetInputDim("X"); ctx->SetOutputDim("Out", in_dims); + ctx->SetOutputDim("Res", in_dims); ctx->ShareLoD("X", /*->*/ "Out"); + ctx->ShareLoD("Y", /*->*/ "Out"); } }; @@ -67,7 +71,9 @@ class TestOpMaker : public OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor), The input tensor of test op."); + AddInput("Y", "(Tensor), The input tensor of test op."); AddOutput("Out", "(Tensor), The output tensor of test op."); + AddOutput("Res", "(Tensor), The output tensor of test op."); AddComment(R"DOC(Test Operator.)DOC"); } }; @@ -144,7 +150,8 @@ std::array>, 4> CreateGraph() { auto* var5 = graph.at(3).back().get(); // insert cluster op node - framework::OpDesc op_desc_1("op1", {{"var1", {""}}}, {{"var2", {""}}}, {}); + framework::OpDesc op_desc_1("op1", {{"X", {"var1"}}}, {{"Out", {"var2"}}}, + {}); graph.at(0).emplace_back(framework::ir::CreateNodeForTest(&op_desc_1)); auto* op1 = graph.at(0).back().get(); @@ -154,8 +161,8 @@ std::array>, 4> CreateGraph() { op1->outputs.emplace_back(var2); var2->inputs.emplace_back(op1); - framework::OpDesc op_desc_3("op3", {{"var3", {""}}, {"var4", {""}}}, - {{"var5", {""}}}, {}); + framework::OpDesc op_desc_3("op3", {{"X", {"var3"}}, {"Y", {"var4"}}}, + {{"Out", {"var5"}}}, {}); graph.at(0).emplace_back(framework::ir::CreateNodeForTest(&op_desc_3)); auto* op3 = graph.at(0).back().get(); @@ -168,8 +175,8 @@ std::array>, 4> CreateGraph() { op3->outputs.emplace_back(var5); var5->inputs.emplace_back(op3); - framework::OpDesc op_desc_2("op2", {{"var2", {""}}}, - {{"var3", {""}}, {"var4", {""}}}, {}); + framework::OpDesc op_desc_2("op2", {{"X", {"var2"}}}, + {{"Out", {"var3"}}, {"Res", {"var4"}}}, {}); graph.at(0).emplace_back(framework::ir::CreateNodeForTest(&op_desc_2)); auto* op2 = graph.at(0).back().get(); diff --git a/paddle/fluid/compiler/paddle2piano/piano_op_kernel_context.cc b/paddle/fluid/compiler/paddle2piano/piano_op_kernel_context.cc index 1efe5ce5e24649..d2a50501e76331 100644 --- a/paddle/fluid/compiler/paddle2piano/piano_op_kernel_context.cc +++ b/paddle/fluid/compiler/paddle2piano/piano_op_kernel_context.cc @@ -50,7 +50,13 @@ symbolization::Operand PianoOpKernelContext::GetInput( HasInput(name), true, platform::errors::NotFound("Input %s is not found in op %s.", name.c_str(), op_->Type().c_str())); - return scope_->GetOperand(name); + const std::vector& input_args = op_->Input(name); + PADDLE_ENFORCE_EQ(input_args.size(), 1, + platform::errors::InvalidArgument( + "Th input %s of op %s should has just 1 argument, " + "but here %zd", + name.c_str(), op_->Type().c_str(), input_args.size())); + return scope_->GetOperand(input_args.at(0)); } void PianoOpKernelContext::SetOutput(const std::string& name, @@ -59,7 +65,13 @@ void PianoOpKernelContext::SetOutput(const std::string& name, op_->HasOutput(name), true, platform::errors::NotFound("Output %s is not found in op %s.", name.c_str(), op_->Type().c_str())); - scope_->SetOperand(name, op); + const std::vector& output_args = op_->Output(name); + PADDLE_ENFORCE_EQ(output_args.size(), 1, + platform::errors::InvalidArgument( + "Th output %s of op %s should has just 1 argument, " + "but here %zd", + name.c_str(), op_->Type().c_str(), output_args.size())); + scope_->SetOperand(output_args.at(0), op); } } // namespace piano diff --git a/paddle/fluid/compiler/paddle2piano/piano_op_kernel_context_test.cc b/paddle/fluid/compiler/paddle2piano/piano_op_kernel_context_test.cc index 1d6e9953ec2450..a20afacc4a3206 100644 --- a/paddle/fluid/compiler/paddle2piano/piano_op_kernel_context_test.cc +++ b/paddle/fluid/compiler/paddle2piano/piano_op_kernel_context_test.cc @@ -159,6 +159,7 @@ TEST(PianoContextTest, basic) { auto* op = global_block->AppendOp(); op->SetType("test"); op->SetInput("X", {"IN1"}); + op->SetInput("Y", {"IN1", "IN2"}); op->SetOutput("Out", {"OUT1"}); op->SetAttr("scale", 3.14f); op->SetAttr("same", 0); @@ -166,7 +167,7 @@ TEST(PianoContextTest, basic) { // create scope and NoteBuilder PianoScope scope; symbolization::Operand op_x; - scope.SetOperand("X", op_x); + scope.SetOperand("IN1", op_x); symbolization::NoteBuilder builder("test_expand"); @@ -182,13 +183,15 @@ TEST(PianoContextTest, basic) { ASSERT_TRUE(ctx.HasInput("X")); // Operand no match for 'operator==' ASSERT_NO_THROW(ctx.GetInput("X")); - ASSERT_FALSE(ctx.HasInput("Y")); + // Input "Y" has two argument + ASSERT_TRUE(ctx.HasInput("Y")); ASSERT_ANY_THROW(ctx.GetInput("Y")); + ASSERT_FALSE(ctx.HasInput("Z")); // test output symbolization::Operand op_out; ASSERT_NO_THROW(ctx.SetOutput("Out", op_out)); - ASSERT_ANY_THROW(ctx.SetOutput("Y", op_out)); + ASSERT_ANY_THROW(ctx.SetOutput("Z", op_out)); // test attribute ASSERT_EQ(ctx.DataTypes(), TestDatatypes());