Skip to content

Commit

Permalink
fix op kernel context get and set operand bug (PaddlePaddle#47)
Browse files Browse the repository at this point in the history
* fix op kernel context get and set operand bug

* fix print lld to zd
  • Loading branch information
thisjiang authored Sep 16, 2021
1 parent a8bc997 commit 7a0e3c7
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 13 deletions.
9 changes: 6 additions & 3 deletions paddle/fluid/compiler/paddle2piano/piano_graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ std::unique_ptr<PianoScope> PianoGraphExecutor::CreateInputOperand(
auto* node = cluster_inputs_.at(id);
PADDLE_ENFORCE_EQ(node->IsVar(), true,
platform::errors::InvalidArgument(
"Cluster Sub-Graph Input should be var"));
"Cluster Sub-Graph Input should be var, but %s not",
node->Name().c_str()));

const auto& var_name = node->Name();

Expand Down Expand Up @@ -71,10 +72,12 @@ GraphNodeVec PianoGraphExecutor::SortInternalCluster() const {
for (auto* n : cluster_) {
PADDLE_ENFORCE_EQ(n->IsOp(), true,
platform::errors::PreconditionNotMet(
"Cluster's node all should be op node"));
"Cluster's node all should be op node, but %s not",
n->Name().c_str()));
PADDLE_ENFORCE_EQ(PianoOpRegistry::IsPianoOp(n->Name()), true,
platform::errors::PreconditionNotMet(
"Cluster's op all should be piano op"));
"Cluster's op all should be piano op, but %s not",
n->Name().c_str()));
// the op's input is var
for (auto* in_var : n->inputs) {
// the var's input is op
Expand Down
17 changes: 12 additions & 5 deletions paddle/fluid/compiler/paddle2piano/piano_graph_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,26 @@ class TestOp : public OperatorWithKernel {

void InferShape(InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "test");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "test");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "test");
OP_INOUT_CHECK(ctx->HasOutput("Res"), "Output", "Res", "test");

auto in_dims = ctx->GetInputDim("X");

ctx->SetOutputDim("Out", in_dims);
ctx->SetOutputDim("Res", in_dims);
ctx->ShareLoD("X", /*->*/ "Out");
ctx->ShareLoD("Y", /*->*/ "Out");
}
};

class TestOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of test op.");
AddInput("Y", "(Tensor), The input tensor of test op.");
AddOutput("Out", "(Tensor), The output tensor of test op.");
AddOutput("Res", "(Tensor), The output tensor of test op.");
AddComment(R"DOC(Test Operator.)DOC");
}
};
Expand Down Expand Up @@ -144,7 +150,8 @@ std::array<std::vector<std::unique_ptr<Node>>, 4> CreateGraph() {
auto* var5 = graph.at(3).back().get();

// insert cluster op node
framework::OpDesc op_desc_1("op1", {{"var1", {""}}}, {{"var2", {""}}}, {});
framework::OpDesc op_desc_1("op1", {{"X", {"var1"}}}, {{"Out", {"var2"}}},
{});
graph.at(0).emplace_back(framework::ir::CreateNodeForTest(&op_desc_1));
auto* op1 = graph.at(0).back().get();

Expand All @@ -154,8 +161,8 @@ std::array<std::vector<std::unique_ptr<Node>>, 4> CreateGraph() {
op1->outputs.emplace_back(var2);
var2->inputs.emplace_back(op1);

framework::OpDesc op_desc_3("op3", {{"var3", {""}}, {"var4", {""}}},
{{"var5", {""}}}, {});
framework::OpDesc op_desc_3("op3", {{"X", {"var3"}}, {"Y", {"var4"}}},
{{"Out", {"var5"}}}, {});
graph.at(0).emplace_back(framework::ir::CreateNodeForTest(&op_desc_3));
auto* op3 = graph.at(0).back().get();

Expand All @@ -168,8 +175,8 @@ std::array<std::vector<std::unique_ptr<Node>>, 4> CreateGraph() {
op3->outputs.emplace_back(var5);
var5->inputs.emplace_back(op3);

framework::OpDesc op_desc_2("op2", {{"var2", {""}}},
{{"var3", {""}}, {"var4", {""}}}, {});
framework::OpDesc op_desc_2("op2", {{"X", {"var2"}}},
{{"Out", {"var3"}}, {"Res", {"var4"}}}, {});
graph.at(0).emplace_back(framework::ir::CreateNodeForTest(&op_desc_2));
auto* op2 = graph.at(0).back().get();

Expand Down
16 changes: 14 additions & 2 deletions paddle/fluid/compiler/paddle2piano/piano_op_kernel_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ symbolization::Operand PianoOpKernelContext::GetInput(
HasInput(name), true,
platform::errors::NotFound("Input %s is not found in op %s.",
name.c_str(), op_->Type().c_str()));
return scope_->GetOperand(name);
const std::vector<std::string>& input_args = op_->Input(name);
PADDLE_ENFORCE_EQ(input_args.size(), 1,
platform::errors::InvalidArgument(
"Th input %s of op %s should has just 1 argument, "
"but here %zd",
name.c_str(), op_->Type().c_str(), input_args.size()));
return scope_->GetOperand(input_args.at(0));
}

void PianoOpKernelContext::SetOutput(const std::string& name,
Expand All @@ -59,7 +65,13 @@ void PianoOpKernelContext::SetOutput(const std::string& name,
op_->HasOutput(name), true,
platform::errors::NotFound("Output %s is not found in op %s.",
name.c_str(), op_->Type().c_str()));
scope_->SetOperand(name, op);
const std::vector<std::string>& output_args = op_->Output(name);
PADDLE_ENFORCE_EQ(output_args.size(), 1,
platform::errors::InvalidArgument(
"Th output %s of op %s should has just 1 argument, "
"but here %zd",
name.c_str(), op_->Type().c_str(), output_args.size()));
scope_->SetOperand(output_args.at(0), op);
}

} // namespace piano
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,15 @@ TEST(PianoContextTest, basic) {
auto* op = global_block->AppendOp();
op->SetType("test");
op->SetInput("X", {"IN1"});
op->SetInput("Y", {"IN1", "IN2"});
op->SetOutput("Out", {"OUT1"});
op->SetAttr("scale", 3.14f);
op->SetAttr("same", 0);

// create scope and NoteBuilder
PianoScope scope;
symbolization::Operand op_x;
scope.SetOperand("X", op_x);
scope.SetOperand("IN1", op_x);

symbolization::NoteBuilder builder("test_expand");

Expand All @@ -182,13 +183,15 @@ TEST(PianoContextTest, basic) {
ASSERT_TRUE(ctx.HasInput("X"));
// Operand no match for 'operator=='
ASSERT_NO_THROW(ctx.GetInput("X"));
ASSERT_FALSE(ctx.HasInput("Y"));
// Input "Y" has two argument
ASSERT_TRUE(ctx.HasInput("Y"));
ASSERT_ANY_THROW(ctx.GetInput("Y"));
ASSERT_FALSE(ctx.HasInput("Z"));

// test output
symbolization::Operand op_out;
ASSERT_NO_THROW(ctx.SetOutput("Out", op_out));
ASSERT_ANY_THROW(ctx.SetOutput("Y", op_out));
ASSERT_ANY_THROW(ctx.SetOutput("Z", op_out));

// test attribute
ASSERT_EQ(ctx.DataTypes(), TestDatatypes());
Expand Down

0 comments on commit 7a0e3c7

Please sign in to comment.