Skip to content

Commit

Permalink
add matmul vectorize (PaddlePaddle#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn authored Mar 21, 2020
1 parent ff82b53 commit 6e6d79e
Show file tree
Hide file tree
Showing 24 changed files with 321 additions and 115 deletions.
8 changes: 8 additions & 0 deletions cinn/backends/_x86_builtin_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,14 @@ inline __m512 cinn_avx512_add(const __m512& a, const __m512& b) { return _mm512_
inline __m512d cinn_avx512_add(const __m512d& a, const __m512d& b) { return _mm512_add_pd(a, b); }
// @}

//! mul
//! Multiplication wrappers: each overload forwards its two operands to the AVX
//! multiply intrinsic that matches the vector width (256/512 bit) and element
//! type (`_ps` for the float vectors, `_pd` for the double vectors).
// @{
inline __m256 cinn_avx256_mul(const __m256& a, const __m256& b) { return _mm256_mul_ps(a, b); }
inline __m256d cinn_avx256_mul(const __m256d& a, const __m256d& b) { return _mm256_mul_pd(a, b); }
inline __m512 cinn_avx512_mul(const __m512& a, const __m512& b) { return _mm512_mul_ps(a, b); }
inline __m512d cinn_avx512_mul(const __m512d& a, const __m512d& b) { return _mm512_mul_pd(a, b); }
// @}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// )END Predefined utilities in CINN
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
9 changes: 1 addition & 8 deletions cinn/backends/codegen_c_x86.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,7 @@ void CodeGenCX86::PrintVecInputArgument(const Expr *op) {
NOT_IMPLEMENTED
}
} else {
auto *load_n = op->As<ir::Load>();

if (load_n) {
Visit(load_n);
return;
}

NOT_IMPLEMENTED
Print(*op);
}
}

Expand Down
2 changes: 1 addition & 1 deletion cinn/common/graph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <glog/logging.h>

#include <algorithm>
#include <functional>
#include <list>
#include <map>
Expand All @@ -13,7 +14,6 @@
#include <unordered_map>
#include <vector>

#include <algorithm>
#include "cinn/common/object.h"
#include "cinn/common/shared.h"

Expand Down
2 changes: 1 addition & 1 deletion cinn/ir/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void IRMutator<T>::Visit(const IfThenElse *expr, T op) {
auto *node = op->template As<IfThenElse>();
IRVisitorBase<void, T>::Visit(&node->condition, &node->condition);
IRVisitorBase<void, T>::Visit(&node->true_case, &node->true_case);
IRVisitorBase<void, T>::Visit(&node->false_case, &node->false_case);
if (node->false_case.defined()) IRVisitorBase<void, T>::Visit(&node->false_case, &node->false_case);
}
template <typename T>
void IRMutator<T>::Visit(const Block *expr, T op) {
Expand Down
4 changes: 3 additions & 1 deletion cinn/lang/lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using poly::Stage;
struct MarkVectorizeMutator : public ir::IRMutator<Expr*> {
const std::map<std::string, ir::VectorizeInfo>& vectorizes;

MarkVectorizeMutator(const std::map<std::string /*tensor name*/, ir::VectorizeInfo>& vectorizes)
explicit MarkVectorizeMutator(const std::map<std::string /*tensor name*/, ir::VectorizeInfo>& vectorizes)
: vectorizes(vectorizes) {}

void operator()(Expr* expr) { ir::IRMutator<Expr*>::Visit(expr, expr); }
Expand Down Expand Up @@ -180,10 +180,12 @@ Expr LowerGroup(const poly::ScheduleGroup& group, const std::map<std::string, Ex
ir::Expr e;
poly::IslAstNodeToCinnExpr(ast, &e);

/*
for (auto& stage : stages) {
VLOG(3) << "run Split separation on " << stage->id() << " " << stage->split_strageties().size() << " strategies";
SplitExpandMutator(stage->id(), stage->split_strageties())(&e);
}
*/

// replace call to the corresponding statement
for (auto& statement : tuple_to_expr) {
Expand Down
7 changes: 4 additions & 3 deletions cinn/optim/ir_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
}

//! Copy an IfThenElse node by visiting its condition and branches.
//! The else-branch is optional: it is visited only when it is defined,
//! otherwise a default-constructed Expr is handed to IfThenElse::Make.
Expr Visit(const IfThenElse* op) override {
  auto condition = Visit(&op->condition);
  auto true_case = Visit(&op->true_case);

  Expr false_case;
  // The visited else-branch must be stored; discarding the result would make
  // every copy lose its else-branch.
  if (op->false_case.defined()) false_case = Visit(&op->false_case);

  return IfThenElse::Make(condition, true_case, false_case);
}

Expand Down
2 changes: 2 additions & 0 deletions cinn/optim/replace_call_with_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ struct ReplaceCallWithExprModifier : public ir::IRMutator<> {

//! Run a ReplaceCallWithExprModifier, built from `statement` and `candidate`, over `e`.
//! @param e          expression handed to the modifier (presumably rewritten in place — TODO confirm).
//! @param statement  name forwarded to the modifier.
//! @param candidate  expression forwarded to the modifier.
void ReplaceCallWithExpr(Expr *e, const std::string &statement, const Expr &candidate) {
ReplaceCallWithExprModifier modifier(statement, candidate);
// NOTE(review): both inputs are logged at INFO level on every call; this looks
// like debugging output — confirm whether it should be VLOG instead.
LOG(INFO) << "statement " << statement;
LOG(INFO) << "candidate " << candidate;
modifier(e);
}

Expand Down
2 changes: 1 addition & 1 deletion cinn/poly/ast_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class AstGen {
AstGen(const isl::set& context, const std::vector<Stage*>& stages, const poly::ScheduleGroup& group);

/**
* Set forloop iterator names.
* Set for-loop iterator names.
* @param names
* @return AstGen itself.
*/
Expand Down
48 changes: 48 additions & 0 deletions cinn/poly/isl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <glog/logging.h>
#include <isl/cpp.h>

#include <algorithm>

#include "cinn/utils/string.h"

namespace cinn {
Expand Down Expand Up @@ -97,5 +99,51 @@ isl::set SetGetDims(isl::set set, const std::vector<int> &dims) {
return set.apply(transform);
}

/**
 * Keep only the first `level` dimensions of a set.
 *
 * A map of the form `{ S[i0, ..., i(n-1)] -> S[i0, ..., i(level-1)] }` is built
 * as text (S is the set's tuple name, n its number of set dimensions), parsed
 * with isl and applied to `set`.
 *
 * @param set   the input set; it is passed to isl_set_apply, which takes it.
 * @param level number of leading dimensions to keep; must be less than n.
 * @return the result of isl_set_apply.
 */
isl_set *isl_get_precending_aixs(isl_set *set, int level) {
  int n = isl_set_dim(set, isl_dim_set);
  CHECK_LT(level, n);

  std::vector<std::string> domain_iterators;
  std::vector<std::string> range_iterators;

  for (int i = 0; i < n; i++) {
    domain_iterators.push_back("i" + std::to_string(i));
  }

  for (int i = 0; i < level; i++) {
    range_iterators.push_back("i" + std::to_string(i));
  }

  // The tuple name is NULL for an unnamed set; fall back to an empty name so a
  // null pointer is never fed to the "%s" conversions below. The name is copied
  // because `set` is given away to isl_set_apply at the end.
  const char *tuple_name = isl_set_get_tuple_name(set);
  const std::string statement = tuple_name ? tuple_name : "";

  std::string repr = utils::StringFormat("{ %s[%s] -> %s[%s] }",
                                         statement.c_str(),
                                         utils::Join(domain_iterators, ", ").c_str(),
                                         statement.c_str(),
                                         utils::Join(range_iterators, ", ").c_str());
  auto transform = isl::manage(isl_map_read_from_str(isl_set_get_ctx(set), repr.c_str()));

  return isl_set_apply(set, transform.release());
}

/**
 * Find how many leading levels of two sets are compatible.
 *
 * For i = 0, 1, ... (below the smaller dimension count) both sets are reduced
 * to their first i dimensions with isl_get_precending_aixs and compared; the
 * scan stops at the first level whose prefixes are not equal.
 *
 * @param a first set (only copies of it are consumed).
 * @param b second set (only copies of it are consumed).
 * @return the last level whose prefixes compared equal, or -1 if none did.
 */
int isl_max_level_compatible(isl_set *a, isl_set *b) {
  int an = isl_set_dim(a, isl_dim_set);
  int bn = isl_set_dim(b, isl_dim_set);
  CHECK_GT(an, 0);
  CHECK_GT(bn, 0);

  const int level_count = std::min(an, bn);
  int compatible_level = -1;
  for (int i = 0; i < level_count; i++) {
    isl::set a_prefix = isl::manage(isl_get_precending_aixs(isl_set_copy(a), i));
    isl::set b_prefix = isl::manage(isl_get_precending_aixs(isl_set_copy(b), i));

    // isl_set_is_equal yields an isl_bool; isl_bool_error is -1 and therefore
    // truthy, so compare against isl_bool_true instead of testing the raw value.
    if (isl_set_is_equal(a_prefix.get(), b_prefix.get()) != isl_bool_true) break;
    compatible_level = i;
  }

  return compatible_level;
}

} // namespace poly
} // namespace cinn
5 changes: 5 additions & 0 deletions cinn/poly/isl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,10 @@ isl::set SetGetDims(isl::set set, const std::vector<int>& dims);
//! Get a representation of the tuple in the map.
std::string isl_map_get_statement_repr(__isl_keep isl_map* map, isl_dim_type type);

//! Reduce `set` to its first `level` dimensions. `set` is taken (__isl_take) and
//! the returned set is given to the caller (__isl_give).
isl_set* __isl_give isl_get_precending_aixs(isl_set* __isl_take set, int level);

//! Get the maximum axis level up to which the two sets have the same domain.
int isl_max_level_compatible(isl_set* __isl_keep a, isl_set* __isl_keep b);

} // namespace poly
} // namespace cinn
26 changes: 26 additions & 0 deletions cinn/poly/poly_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
#include <glog/logging.h>

#include <deque>
#include <limits>
#include <map>
#include <set>

#include "cinn/poly/isl_utils.h"

namespace cinn {
namespace poly {
Expand Down Expand Up @@ -275,5 +280,26 @@ void PolyScheduler::ScheduleGroups() {
}
}

//! Build the schedule by chaining every pair of neighbouring stages.
//! For each adjacent pair (a, b) in `stages_`, `After(a, b, level)` is called,
//! where `level` is the result of isl_max_level_compatible on the two
//! transformed domains, clamped to be at least 0.
void PolyGroupScheduler::Build() {
  // `i + 1 < size()` keeps the comparison in unsigned arithmetic without the
  // wrap-around that `size() - 1` has for an empty vector.
  for (size_t i = 0; i + 1 < stages_.size(); ++i) {
    Stage* a = stages_[i];
    Stage* b = stages_[i + 1];

    auto a_set = a->transformed_domain();
    auto b_set = b->transformed_domain();

    const int max_precending_level = std::max(isl_max_level_compatible(a_set.get(), b_set.get()), 0);
    After(*a, *b, max_precending_level);
  }
}

//! Register every given stage with the scheduler, then close stage registration.
//! Fails the check when `stages` is empty.
PolyGroupScheduler::PolyGroupScheduler(const std::vector<Stage*>& stages) : stages_(stages) {
  CHECK_GT(stages.size(), 0) << "No stage is provided";

  for (size_t idx = 0; idx < stages.size(); ++idx) {
    Stage* current = stages[idx];
    AddStage(*current);
  }

  FinishStageAdd();
}

} // namespace poly
} // namespace cinn
30 changes: 5 additions & 25 deletions cinn/poly/poly_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,18 @@ namespace cinn {
namespace poly {

/**
 * Schedule a single group with iterator domain considered and follow the stage order.
 */
class PolyGroupScheduler : public SchedulerBase {
 public:
  //! Constructor, this will build a DAG based on the stages.
  //! The stage pointers are copied; the Stage objects themselves are not owned.
  explicit PolyGroupScheduler(const std::vector<Stage *> &stages);

  //! Build the schedule, that is set the time schedule following each edge.
  void Build();

 private:
  //! Stages in scheduling order. Held by value: a reference member would dangle
  //! whenever the scheduler is constructed from a temporary vector (for example
  //! a braced list of stage pointers).
  std::vector<Stage *> stages_;
};

/**
Expand Down
49 changes: 1 addition & 48 deletions cinn/poly/poly_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,5 @@
#include <gtest/gtest.h>

namespace cinn {
namespace poly {

TEST(Scheduler, basic) {
isl::ctx ctx(Context::Global().isl_ctx());
isl::set A_set(ctx, "[]->{ A[i,j]: 0<i,j<100 }");
auto A = Stage::New(A_set);
isl::set B_set(ctx, "[]->{ B[i,j]: 0<i,j<100 }");
auto B = Stage::New(B_set);
LOG(INFO) << A->transform();

PolyGroupScheduler scheduler({A.get(), B.get()});
scheduler.After(*A, *B, 1);
scheduler.Build();

auto schedule = scheduler.schedule_map();

EXPECT_EQ(utils::GetStreamCnt(schedule["A"]), "{ A[i, j] -> [t0 = 0, d0 = i, t1 = 0, d1 = j] }");
EXPECT_EQ(utils::GetStreamCnt(schedule["B"]), "{ B[i, j] -> [t0 = 0, d0 = i, t1 = 1, d1 = j] }");

for (auto item : schedule) {
LOG(INFO) << item.first << " " << item.second;
}
}

TEST(Scheduler, basic_with_transform) {
isl::ctx ctx = Context::Global().isl_ctx();
auto A = Stage::New(isl::set(ctx, "[]->{ A[i,j]: 0<i,j<100 }"));
auto B = Stage::New(isl::set(ctx, "[]->{ B[i,j]: 0<i,j<100 }"));
auto x = A->Split("i", 4);
LOG(INFO) << A->transform();
B->Split(Iterator("j"), 6);
LOG(INFO) << B->transform();

PolyGroupScheduler scheduler({A.get(), B.get()});
scheduler.After(*A, *B, 1);
scheduler.Build();
auto schedule = scheduler.schedule_map();
for (auto item : schedule) {
LOG(INFO) << item.first << " " << item.second;
}

EXPECT_EQ(utils::GetStreamCnt(schedule["A"]),
"{ A[i_outer, i_inner, j] -> [t0 = 0, d0 = i_outer, t1 = 0, d1 = i_inner, t2 = 0, d2 = j] }");
EXPECT_EQ(utils::GetStreamCnt(schedule["B"]),
"{ B[i, j_outer, j_inner] -> [t0 = 0, d0 = i, t1 = 1, d1 = j_outer, t2 = 0, d2 = j_inner] }");
}

} // namespace poly
namespace poly {} // namespace poly
} // namespace cinn
15 changes: 12 additions & 3 deletions cinn/poly/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,20 @@ std::vector<Stage *> GatherStagesInTensors(const std::vector<ir::Tensor> &xs, bo

//! Collect the stages of every node in `group`, schedule them with a
//! PolyGroupScheduler in node order, and return the resulting schedule map.
//! Every node must carry a non-null stage.
std::map<std::string, isl::map> CollectScheduleMapFromGroup(const ScheduleGroup &group) {
  std::vector<Stage *> stages;
  LOG(INFO) << "Group to schedule as:";
  for (auto &node : group.nodes) {
    // Validate the stage before it is dereferenced for logging.
    CHECK(node->stage);
    LOG(INFO) << node->stage->id();
    stages.push_back(node->stage);
  }

  PolyGroupScheduler group_scheduler(stages);
  group_scheduler.Build();

  return group_scheduler.schedule_map();
}

void SchedulerBase::AddStage(const Stage &x) {
Expand Down
2 changes: 1 addition & 1 deletion cinn/poly/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ struct ScheduleGraphEdge : public common::GraphEdge {
*/
struct ScheduleGraphNode : public common::GraphNode {
TimeSchedule time_schedule;
Stage *stage;
Stage *stage{};

//! NOTE this id is not human-readable.
// std::string id() const override { return std::to_string(reinterpret_cast<size_t>(this)); }
Expand Down
9 changes: 9 additions & 0 deletions cinn/poly/stage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,15 @@ void Stage::Vectorize(int level, int factor) {
vectorize_info_.set(level + 1 /*inner*/, factor);
}

//! Vectorize the axis called `axis` by `factor`.
//! The name is looked up among the dimension names of the transformed domain
//! and its position is forwarded to the level-based overload; the check fails
//! when no dimension has that name.
void Stage::Vectorize(const std::string &axis, int factor) {
  auto dim_names = GetDimNames(transformed_domain());

  auto found = std::find(dim_names.begin(), dim_names.end(), axis);
  CHECK(found != dim_names.end()) << "No dimension called " << axis;

  auto level = std::distance(dim_names.begin(), found);
  Vectorize(level, factor);
}

//! Vectorize the axis identified by `axis.id`; forwards to the name-based overload.
void Stage::Vectorize(const Iterator &axis, int factor) {
  Vectorize(axis.id, factor);
}

std::string Stage::ith_dim_name(int level) {
auto dims = GetDimNames(transformed_domain());
CHECK_LT(level, dims.size());
Expand Down
Loading

0 comments on commit 6e6d79e

Please sign in to comment.