add matmul_with_packed and fix topological order bug (PaddlePaddle#69)
Superjomn authored Mar 6, 2020
1 parent e3e570e commit eb5a106
Showing 8 changed files with 148 additions and 49 deletions.
74 changes: 74 additions & 0 deletions cinn/backends/codegen_c_test.cc
@@ -235,5 +235,79 @@ void matmul(struct cinn_buffer_t *_A, struct cinn_buffer_t *_B, struct cinn_buff
ASSERT_EQ(Trim(tgt), Trim(out));
}

TEST(CodeGenC, matmul_with_packed) {
const int M = 100;
const int K = 20;
const int N = 50;
const int bn = 4;
Placeholder<float> A("A", {M, K});
Placeholder<float> B("B", {K, N});

lang::Buffer packedB_buf(Float(32));
lang::Buffer C_buf(Float(32));

// TODO(Superjomn) Make sure the domain works.
auto packedB = Compute(
{N / bn, K, bn}, [&](Expr i, Expr j, Expr k) { return B(j, i * bn + k); }, "PackedB");
packedB->Bind(packedB_buf);
auto C = Compute(
{M, N, K}, [&](Expr i, Expr j, Expr k) { return A(i, k) * packedB(j / bn, k, j % bn); }, "C", 2 /*reduce axis*/);
C->Bind(C_buf);
// packedB->stage()->ComputeAt(C->stage(), 1);

// Code gen
auto funcs = Lower("matmul_with_packing", {A, B, packedB, C});
ASSERT_EQ(funcs.size(), 1UL);

Target target;
target.arch = Target::Arch::X86;
target.bits = Target::Bit::k32;
target.os = Target::OS::Linux;

Module module("module1", target);
module.Append(funcs.front());
module.Append(C_buf);
module.Append(packedB_buf);

CodeGenC codegen(target);
auto out = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
std::cout << "codegen C:" << std::endl << out << std::endl;

auto target_out = R"ROC(
#include <cinn_runtime.h>
#include <stdio.h>
cinn_buffer_t* C = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t());
cinn_buffer_t* PackedB = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t());
void matmul_with_packing(struct cinn_buffer_t *_A, struct cinn_buffer_t *_B, struct cinn_buffer_t *_PackedB, struct cinn_buffer_t *_C)
{
cinn_buffer_malloc((void*)(0), _A);
cinn_buffer_malloc((void*)(0), _B);
cinn_buffer_malloc((void*)(0), _PackedB);
cinn_buffer_malloc((void*)(0), _C);
float* A = (float*)(cinn_buffer_get_data_handle(_A));
float* B = (float*)(cinn_buffer_get_data_handle(_B));
float* C = (float*)(cinn_buffer_get_data_handle(_C));
float* PackedB = (float*)(cinn_buffer_get_data_handle(_PackedB));
for (int32_t i = 0; (i <= 11); i += 1){
for (int32_t j = 0; (j <= 19); j += 1){
for (int32_t k = 0; (k <= 3); k += 1){
PackedB[(((i * 20) + (j * 4)) + k)] = B[((j * 50) + ((i * 4) + k))];
};
};
};
for (int32_t i = 0; (i <= 99); i += 1){
for (int32_t j = 0; (j <= 49); j += 1){
for (int32_t k = 0; (k <= 19); k += 1){
C[((i * 50) + j)] = (A[((i * 20) + k)] * PackedB[((((j / 4) * 20) + (k * 4)) + (j % 4))]);
};
};
};
}
)ROC";

ASSERT_EQ(utils::Trim(target_out), utils::Trim(out));
}

} // namespace backends
} // namespace cinn
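The new test drives CINN's array-packing transform for matmul. As a standalone illustration (not part of the commit; sizes and names are invented, and N is chosen divisible by bn for simplicity), the sketch below shows the layout trick in plain C++: B is copied into [N/bn, K, bn] blocks so that consecutive j iterations read consecutive PackedB elements, and the packed indexing reproduces the plain matmul exactly.

// Standalone sketch of the array packing used by matmul_with_packed above.
// Not part of the commit; sizes are illustrative and N is divisible by bn.
#include <cassert>
#include <vector>

int main() {
  const int M = 8, K = 6, N = 8, bn = 4;
  std::vector<float> A(M * K), B(K * N), C(M * N, 0.f), C_ref(M * N, 0.f);
  for (int i = 0; i < M * K; ++i) A[i] = static_cast<float>(i % 7);
  for (int i = 0; i < K * N; ++i) B[i] = static_cast<float>(i % 5);

  // Pack: PackedB[i][j][k] = B[j][i * bn + k], mirroring the Compute above.
  std::vector<float> PackedB((N / bn) * K * bn);
  for (int i = 0; i < N / bn; ++i)
    for (int j = 0; j < K; ++j)
      for (int k = 0; k < bn; ++k)
        PackedB[(i * K + j) * bn + k] = B[j * N + i * bn + k];

  // Matmul through the packed layout: B[k][j] == PackedB[j / bn][k][j % bn].
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k)
        C[i * N + j] += A[i * K + k] * PackedB[((j / bn) * K + k) * bn + j % bn];

  // Reference matmul over the unpacked B; results must match exactly.
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k)
        C_ref[i * N + j] += A[i * K + k] * B[k * N + j];

  for (int i = 0; i < M * N; ++i) assert(C[i] == C_ref[i]);
  return 0;
}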
69 changes: 31 additions & 38 deletions cinn/common/graph_utils.cc
@@ -13,43 +13,6 @@ namespace common {

namespace {

- void TopologicalSortUtil(GraphNode *node,
-                          std::set<GraphNode *> *visited,
-                          std::stack<GraphNode *> *stack,
-                          std::vector<GraphNode *> *order,
-                          std::vector<GraphEdge *> *edge_order) {
-   node->VisitOnce();
-   if (!node->visited()) return;
-   CHECK(!visited->count(node)) << "duplicate visit current node";
-
-   // Mark the current node as visited.
-   visited->insert(node);
-   order->push_back(node);
-
-   for (auto &e : node->outlinks()) {
-     if (!visited->count(e->sink())) {
-       edge_order->push_back(e.get());
-       TopologicalSortUtil(e->sink(), visited, stack, order, edge_order);
-     }
-   }
-
-   stack->push(node);
- }
-
- std::tuple<Graph::node_order_t, Graph::edge_order_t> TopologicalSort(const std::vector<GraphNode *> &nodes) {
-   std::stack<GraphNode *> stack;
-   std::set<GraphNode *> visited;   // Tell whether a node is visited
-   std::vector<GraphNode *> order;  // nodes visited in order
-   std::vector<GraphEdge *> edges;  // edges visited in order
-
-   for (auto *node : nodes) {
-     if (!visited.count(node)) {
-       TopologicalSortUtil(node, &visited, &stack, &order, &edges);
-     }
-   }
-   return std::make_tuple(std::move(order), std::move(edges));
- }
-
void DFSSortUtil(const GraphNode *node, std::vector<GraphNode *> *order) {}

std::vector<GraphNode *> DFSSort(const std::vector<GraphNode *> &nodes) {
@@ -91,7 +54,37 @@ std::vector<GraphNode *> Graph::nodes() {
}

std::tuple<std::vector<GraphNode *>, std::vector<GraphEdge *>> Graph::topological_order() {
-   return TopologicalSort(nodes());
+   std::vector<GraphNode *> node_order;
+   std::vector<GraphEdge *> edge_order;
+
+   std::deque<GraphNode *> queue;
+   // collect in-degree.
+   std::map<GraphNode *, int> indegree;
+   for (auto *n : nodes()) {
+     indegree[n] = n->inlinks().size();
+   }
+
+   // insert start points first.
+   for (auto *n : start_points()) {
+     queue.push_back(n);
+   }
+
+   // start to visit
+   while (!queue.empty()) {
+     auto *top_node = queue.front();
+     queue.pop_front();
+     node_order.push_back(top_node);
+
+     for (auto &edge : top_node->outlinks()) {
+       edge_order.push_back(edge.get());
+       auto *sink = edge->sink();
+       if (--indegree[sink] == 0) {
+         queue.push_back(sink);
+       }
+     }
+   }
+
+   return std::make_tuple(node_order, edge_order);
}

std::vector<GraphNode *> Graph::dfs_order() { return std::vector<GraphNode *>(); }
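The bug named in the commit title is visible in the removed helper above: it pushed each node into `order` in DFS pre-order, so a node could be emitted before one of its predecessors (for the graph B -> A in the new test, starting the walk at A yields A, B). The replacement is Kahn's algorithm: seed a queue with the start points and release a node only once all of its incoming edges have been consumed. A standalone sketch of the same scheme (not part of the commit; plain std containers stand in for the CINN graph classes):

// Kahn's algorithm over an adjacency map; node names are illustrative.
#include <cassert>
#include <deque>
#include <map>
#include <string>
#include <vector>

std::vector<std::string> TopologicalOrder(
    const std::map<std::string, std::vector<std::string>>& out_edges) {
  // Count incoming edges for every node.
  std::map<std::string, int> indegree;
  for (const auto& kv : out_edges) {
    indegree[kv.first];  // ensure every node has an entry
    for (const auto& sink : kv.second) indegree[sink]++;
  }

  // Seed the queue with nodes that have no predecessors.
  std::deque<std::string> queue;
  for (const auto& kv : indegree)
    if (kv.second == 0) queue.push_back(kv.first);

  std::vector<std::string> order;
  while (!queue.empty()) {
    auto node = queue.front();
    queue.pop_front();
    order.push_back(node);
    auto it = out_edges.find(node);
    if (it == out_edges.end()) continue;
    // Releasing a node may drop a successor's in-degree to zero.
    for (const auto& sink : it->second)
      if (--indegree[sink] == 0) queue.push_back(sink);
  }
  // Nodes on a cycle never reach in-degree zero and would be dropped here.
  assert(order.size() == indegree.size() && "graph must be acyclic");
  return order;
}

int main() {
  // Same shape as the new graph_utils_test case: B -> A.
  auto order = TopologicalOrder({{"B", {"A"}}, {"A", {}}});
  assert(order[0] == "B" && order[1] == "A");
  return 0;
}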
30 changes: 30 additions & 0 deletions cinn/common/graph_utils_test.cc
@@ -39,6 +39,20 @@ std::unique_ptr<Graph> CreateGraph0() {
return graph;
}

std::unique_ptr<Graph> CreateGraph1() {
std::unique_ptr<Graph> graph(new Graph);

auto* A = make_shared<GraphNodeWithName>("A");
auto* B = make_shared<GraphNodeWithName>("B");

graph->RegisterNode("A", A);
graph->RegisterNode("B", B);

B->LinkTo(A);

return graph;
}

TEST(Graph, basic) {
// Create nodes: A, B, C, D, E
auto graph = CreateGraph0();
@@ -68,5 +82,21 @@ TEST(Graph, Visualize) {
LOG(INFO) << "graph:\n" << graph->Visualize();
}

TEST(Graph, simple) {
auto graph = CreateGraph1();
Graph::node_order_t node_order;
Graph::edge_order_t edge_order;
std::tie(node_order, edge_order) = graph->topological_order();

LOG(INFO) << "graph1 " << graph->Visualize();

std::vector<GraphNode*> node_order_target({graph->RetriveNode("B"), graph->RetriveNode("A")});

ASSERT_EQ(node_order.size(), node_order_target.size());
for (int i = 0; i < node_order.size(); i++) {
EXPECT_EQ(node_order[i]->id(), node_order_target[i]->id());
}
}

} // namespace common
} // namespace cinn
6 changes: 3 additions & 3 deletions cinn/lang/compute.cc
@@ -102,11 +102,11 @@ ir::Tensor Compute(const std::vector<int> &dims,

auto unique_name = name.empty() ? Context::Global().NewName("tensor") : name;

-   auto op = ir::ComputeOp::Make(unique_name, "" /*tag*/, {}, fn, domain);
-   auto tensor = ir::_Tensor_::Make(unique_name, shape, op);
-   if (reduce_axis >= 0) tensor->set_reduce_axis(reduce_axis);
+   auto op = ir::ComputeOp::Make(unique_name, "" /*tag*/, {}, fn, domain);
+   auto tensor = ir::_Tensor_::Make(unique_name, shape, op);
    tensor->axis = axis;
    tensor->domain = domain;
+   if (reduce_axis >= 0) tensor->set_reduce_axis(reduce_axis);
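+   // set_reduce_axis now validates the axis against tensor->domain (see the tensor.h hunk below), so it must run only after domain is assigned.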
return tensor;
}

4 changes: 3 additions & 1 deletion cinn/lang/lower.cc
@@ -87,6 +87,7 @@ std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Te

auto stages = poly::GatherStagesInTensors(args);
auto graph = poly::CreateGraph(stages);
LOG(INFO) << "Graph:\n" << graph->Visualize();

// Create a dic for stages and tensors.
std::map<std::string, Stage*> stage_dic;
@@ -117,7 +118,7 @@ std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Te
CHECK(args_names.count(node->id())) << "The dependency tensor [" << node->id() << "] not in the inputs";
}

-   auto schedule = poly::CreateSchedule(stages);
+   auto schedule = poly::CreateSchedule(stages, poly::ScheduleKind::Poly);

// generate the expressions for each group.
std::vector<Expr> exprs;
@@ -126,6 +127,7 @@ std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Te
CHECK_GT(group.nodes.size(), 0) << "group is empty";
std::map<std::string, Expr> tuple_to_expr;
for (auto& node : group.nodes) {
LOG(INFO) << "graph node " << node->id();
auto& tensor = tensor_dic.at(node->id());
// NOTE here just schedule the compute node.
if (!tensor->is_compute_node()) continue;
4 changes: 2 additions & 2 deletions cinn/lang/tensor.h
@@ -142,9 +142,9 @@ class _Tensor_ : public ExprNode<_Tensor_> {

void set_reduce_axis(int v) {
CHECK_EQ(reduce_axis, -1) << "duplicate set reduce_axis";
-   CHECK(!shape.empty()) << "Shape is not set";
+   CHECK(!domain.empty()) << "Domain is not set";
    CHECK_GE(v, 0);
-   CHECK_LT(v, shape.size());
+   CHECK_LT(v, domain.size());
reduce_axis = v;
}

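Why the check moves from shape to domain: for a reduction the reduce axis belongs to the iteration domain but is dropped from the output shape, so shape.size() can be smaller than the axis index being validated (in the matmul test above, C has domain {M, N, K} but shape {M, N}, with reduce_axis 2). A hypothetical illustration with plain vectors (not the CINN API):

// Hypothetical illustration, not CINN API: validating the reduce axis of
// C[i, j] = sum_k A[i, k] * B[k, j] against shape vs. domain.
#include <cassert>
#include <vector>

int main() {
  std::vector<int> domain = {100, 50, 20};  // i, j, k -- k is the reduce axis
  std::vector<int> shape  = {100, 50};      // the reduce axis is dropped
  int reduce_axis = 2;
  assert(reduce_axis < static_cast<int>(domain.size()));    // new check: passes
  assert(!(reduce_axis < static_cast<int>(shape.size())));  // old check would reject
  return 0;
}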
2 changes: 1 addition & 1 deletion cinn/poly/graph.cc
@@ -83,7 +83,7 @@ std::unique_ptr<DataFlowGraph> CreateGraph(const std::vector<Stage*>& stages) {
// We removed some node in the original stages(such as placeholders), so that there might be missing of some input
// nodes, just ignore the dependence.
if (input_it != std::end(id2stage)) {
-       auto& input_node = id2stage.at(depend_statement);
+       auto& input_node = input_it->second;
input_node->LinkTo(id2stage.at(stage->id()).get());
}
}
8 changes: 4 additions & 4 deletions cinn/poly/stage.cc
@@ -140,11 +140,11 @@ Iterator Stage::Fuse(const Iterator &level0, const Iterator &level1) {
std::vector<std::string> Stage::input_statements() const {
if (!expr_.defined()) return {};
VLOG(3) << "stage " << id() << " expr: " << expr_;
-   auto call_exprs = ir::CollectIRNodes(expr_, [](const Expr *x) { return x->As<ir::Call>(); });
+   auto load_exprs = ir::CollectIRNodes(expr_, [](const Expr *x) { return x->As<ir::Load>(); });
    std::set<std::string> statements;
-   for (auto &expr : call_exprs) {
-     auto call_name = expr.As<ir::Call>()->name;
-     if (call_name != id()) statements.insert(call_name);
+   for (auto &expr : load_exprs) {
+     auto tensor_name = ir::BufferGetTensorName(expr.As<ir::Load>()->buffer.As<ir::_Buffer_>());
+     if (tensor_name != id()) statements.insert(tensor_name);
}
return std::vector<std::string>(statements.begin(), statements.end());
}
