From 6e6d79e5445ad0a8bd381135a1faf56aa3a533a4 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Sat, 21 Mar 2020 23:05:31 +0800 Subject: [PATCH] add matmul vectorize (#96) --- cinn/backends/_x86_builtin_source.cc | 8 ++ cinn/backends/codegen_c_x86.cc | 9 +-- cinn/common/graph_utils.h | 2 +- cinn/ir/ir_mutator.h | 2 +- cinn/lang/lower.cc | 4 +- cinn/optim/ir_copy.cc | 7 +- cinn/optim/replace_call_with_expr.cc | 2 + cinn/poly/ast_gen.h | 2 +- cinn/poly/isl_utils.cc | 48 ++++++++++++ cinn/poly/isl_utils.h | 5 ++ cinn/poly/poly_scheduler.cc | 26 ++++++ cinn/poly/poly_scheduler.h | 30 ++----- cinn/poly/poly_scheduler_test.cc | 49 +----------- cinn/poly/schedule.cc | 15 +++- cinn/poly/schedule.h | 2 +- cinn/poly/stage.cc | 9 +++ cinn/poly/stage.h | 12 ++- cinn/runtime/cinn_x86_device_impl.cc | 1 + cinn/utils/CMakeLists.txt | 2 +- cinn/utils/timer.cc | 17 ++++ cinn/utils/timer.h | 21 +++++ tests/CMakeLists.txt | 2 + tests/test02_matmul_case.cc | 48 +++++++++--- tests/test02_matmul_main.cc | 113 +++++++++++++++++++++++++-- 24 files changed, 321 insertions(+), 115 deletions(-) create mode 100644 cinn/utils/timer.cc create mode 100644 cinn/utils/timer.h diff --git a/cinn/backends/_x86_builtin_source.cc b/cinn/backends/_x86_builtin_source.cc index 6a2a3744d29c92..f014730f1ff721 100644 --- a/cinn/backends/_x86_builtin_source.cc +++ b/cinn/backends/_x86_builtin_source.cc @@ -329,6 +329,14 @@ inline __m512 cinn_avx512_add(const __m512& a, const __m512& b) { return _mm512_ inline __m512d cinn_avx512_add(const __m512d& a, const __m512d& b) { return _mm512_add_pd(a, b); } // @} +//! 
mul +// @{ +inline __m256 cinn_avx256_mul(const __m256& a, const __m256& b) { return _mm256_mul_ps(a, b); } +inline __m256d cinn_avx256_mul(const __m256d& a, const __m256d& b) { return _mm256_mul_pd(a, b); } +inline __m512 cinn_avx512_mul(const __m512& a, const __m512& b) { return _mm512_mul_ps(a, b); } +inline __m512d cinn_avx512_mul(const __m512d& a, const __m512d& b) { return _mm512_mul_pd(a, b); } +// @} + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// /// )END Predefined utilities in CINN //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cinn/backends/codegen_c_x86.cc b/cinn/backends/codegen_c_x86.cc index 34a6b000737756..a88e63046d8dcf 100644 --- a/cinn/backends/codegen_c_x86.cc +++ b/cinn/backends/codegen_c_x86.cc @@ -80,14 +80,7 @@ void CodeGenCX86::PrintVecInputArgument(const Expr *op) { NOT_IMPLEMENTED } } else { - auto *load_n = op->As(); - - if (load_n) { - Visit(load_n); - return; - } - - NOT_IMPLEMENTED + Print(*op); } } diff --git a/cinn/common/graph_utils.h b/cinn/common/graph_utils.h index 220445fbbf1f6f..9b3c6ad373e48f 100644 --- a/cinn/common/graph_utils.h +++ b/cinn/common/graph_utils.h @@ -3,6 +3,7 @@ #include +#include #include #include #include @@ -13,7 +14,6 @@ #include #include -#include #include "cinn/common/object.h" #include "cinn/common/shared.h" diff --git a/cinn/ir/ir_mutator.h b/cinn/ir/ir_mutator.h index 28e2a5064b6d4f..16b86e84bf0d69 100644 --- a/cinn/ir/ir_mutator.h +++ b/cinn/ir/ir_mutator.h @@ -83,7 +83,7 @@ void IRMutator::Visit(const IfThenElse *expr, T op) { auto *node = op->template As(); IRVisitorBase::Visit(&node->condition, &node->condition); IRVisitorBase::Visit(&node->true_case, &node->true_case); - IRVisitorBase::Visit(&node->false_case, &node->false_case); + if (node->false_case.defined()) IRVisitorBase::Visit(&node->false_case, &node->false_case); } template 
void IRMutator::Visit(const Block *expr, T op) { diff --git a/cinn/lang/lower.cc b/cinn/lang/lower.cc index 28ea499cfee362..958abb2107707a 100644 --- a/cinn/lang/lower.cc +++ b/cinn/lang/lower.cc @@ -23,7 +23,7 @@ using poly::Stage; struct MarkVectorizeMutator : public ir::IRMutator { const std::map& vectorizes; - MarkVectorizeMutator(const std::map& vectorizes) + explicit MarkVectorizeMutator(const std::map& vectorizes) : vectorizes(vectorizes) {} void operator()(Expr* expr) { ir::IRMutator::Visit(expr, expr); } @@ -180,10 +180,12 @@ Expr LowerGroup(const poly::ScheduleGroup& group, const std::mapid() << " " << stage->split_strageties().size() << " strategies"; SplitExpandMutator(stage->id(), stage->split_strageties())(&e); } + */ // replace call to the corresponding statement for (auto& statement : tuple_to_expr) { diff --git a/cinn/optim/ir_copy.cc b/cinn/optim/ir_copy.cc index ae6becf8b0ef09..fca82e1802d94b 100644 --- a/cinn/optim/ir_copy.cc +++ b/cinn/optim/ir_copy.cc @@ -43,9 +43,10 @@ struct IRCopyVisitor : public ir::IRVisitorBase { } Expr Visit(const IfThenElse* op) override { - auto condition = Visit(&op->condition); - auto true_case = Visit(&op->true_case); - auto false_case = Visit(&op->false_case); + auto condition = Visit(&op->condition); + auto true_case = Visit(&op->true_case); + Expr false_case; /* stays undefined only when the source node has no else-branch */ + if (op->false_case.defined()) false_case = Visit(&op->false_case); /* keep the copied branch instead of discarding it */ return IfThenElse::Make(condition, true_case, false_case); } diff --git a/cinn/optim/replace_call_with_expr.cc b/cinn/optim/replace_call_with_expr.cc index a46fecd141c1f8..b932dfe8316ece 100644 --- a/cinn/optim/replace_call_with_expr.cc +++ b/cinn/optim/replace_call_with_expr.cc @@ -37,6 +37,8 @@ struct ReplaceCallWithExprModifier : public ir::IRMutator<> { void ReplaceCallWithExpr(Expr *e, const std::string &statement, const Expr &candidate) { ReplaceCallWithExprModifier modifier(statement, candidate); + LOG(INFO) << "statement " << statement; + LOG(INFO) << "candidate " << candidate; modifier(e); } 
diff --git a/cinn/poly/ast_gen.h b/cinn/poly/ast_gen.h index c8a45c25071bf4..b3ed263eae263b 100644 --- a/cinn/poly/ast_gen.h +++ b/cinn/poly/ast_gen.h @@ -27,7 +27,7 @@ class AstGen { AstGen(const isl::set& context, const std::vector& stages, const poly::ScheduleGroup& group); /** - * Set forloop iterator names. + * Set for-loop iterator names. * @param names * @return AstGen itself. */ diff --git a/cinn/poly/isl_utils.cc b/cinn/poly/isl_utils.cc index 21510cff9e986d..1abcc8662f8d20 100644 --- a/cinn/poly/isl_utils.cc +++ b/cinn/poly/isl_utils.cc @@ -3,6 +3,8 @@ #include #include +#include + #include "cinn/utils/string.h" namespace cinn { @@ -97,5 +99,51 @@ isl::set SetGetDims(isl::set set, const std::vector &dims) { return set.apply(transform); } +isl_set *isl_get_precending_aixs(isl_set *set, int level) { + int n = isl_set_dim(set, isl_dim_set); + CHECK_LT(level, n); + + std::vector domain_iterators; + std::vector range_iterators; + + for (int i = 0; i < n; i++) { + domain_iterators.push_back("i" + std::to_string(i)); + } + + for (int i = 0; i < level; i++) { + range_iterators.push_back("i" + std::to_string(i)); + } + + const char *statement = isl_set_get_tuple_name(set); + + std::string repr = utils::StringFormat("{ %s[%s] -> %s[%s] }", + statement, + utils::Join(domain_iterators, ", ").c_str(), + statement, + utils::Join(range_iterators, ", ").c_str()); + auto transform = isl::manage(isl_map_read_from_str(isl_set_get_ctx(set), repr.c_str())); + + return isl_set_apply(set, transform.release()); +} + +int isl_max_level_compatible(isl_set *a, isl_set *b) { + int an = isl_set_dim(a, isl_dim_set); + int bn = isl_set_dim(b, isl_dim_set); + CHECK_GT(an, 0); + CHECK_GT(bn, 0); + + int compatible_level = -1; + for (int i = 0; i < std::min(an, bn); i++) { + isl::set a_prefix = isl::manage(isl_get_precending_aixs(isl_set_copy(a), i)); + isl::set b_prefix = isl::manage(isl_get_precending_aixs(isl_set_copy(b), i)); + if (isl_set_is_equal(a_prefix.get(), b_prefix.get())) + 
compatible_level = i; + else + break; + } + + return compatible_level; +} + } // namespace poly } // namespace cinn diff --git a/cinn/poly/isl_utils.h b/cinn/poly/isl_utils.h index aab818ff997d8b..1ab9ff1c182426 100644 --- a/cinn/poly/isl_utils.h +++ b/cinn/poly/isl_utils.h @@ -30,5 +30,10 @@ isl::set SetGetDims(isl::set set, const std::vector& dims); //! Get a representation of the tuple in the map. std::string isl_map_get_statement_repr(__isl_keep isl_map* map, isl_dim_type type); +isl_set* __isl_give isl_get_precending_aixs(isl_set* __isl_take set, int level); + +//! Get the maximum level of axis that is has the same domain. +int isl_max_level_compatible(isl_set* __isl_keep a, isl_set* __isl_keep b); + } // namespace poly } // namespace cinn diff --git a/cinn/poly/poly_scheduler.cc b/cinn/poly/poly_scheduler.cc index 0ae7e1d4795ae7..e4ca0ae0e0f324 100644 --- a/cinn/poly/poly_scheduler.cc +++ b/cinn/poly/poly_scheduler.cc @@ -3,6 +3,11 @@ #include #include +#include +#include +#include + +#include "cinn/poly/isl_utils.h" namespace cinn { namespace poly { @@ -275,5 +280,26 @@ void PolyScheduler::ScheduleGroups() { } } +void PolyGroupScheduler::Build() { + for (int i = 0; i < stages_.size() - 1; i++) { + Stage* a = stages_[i]; + Stage* b = stages_[i + 1]; + + auto a_set = a->transformed_domain(); + auto b_set = b->transformed_domain(); + + int max_precending_level = std::max(isl_max_level_compatible(a_set.get(), b_set.get()), 0); + After(*a, *b, max_precending_level); + } +} + +PolyGroupScheduler::PolyGroupScheduler(const std::vector& stages) : stages_(stages) { + CHECK_GT(stages.size(), 0) << "No stage is provided"; + for (auto* stage : stages) { + AddStage(*stage); + } + FinishStageAdd(); +} + } // namespace poly } // namespace cinn diff --git a/cinn/poly/poly_scheduler.h b/cinn/poly/poly_scheduler.h index 05f5a4c8a35e7c..cd0a0df6e38e51 100644 --- a/cinn/poly/poly_scheduler.h +++ b/cinn/poly/poly_scheduler.h @@ -18,38 +18,18 @@ namespace cinn { namespace poly { 
/** - * Schedule a single group with iterator domain considered. + * Schedule a single group with iterator domain considered and follow the stage order. */ class PolyGroupScheduler : public SchedulerBase { public: //! Constructor, this will build a DAG based on the stages. - explicit PolyGroupScheduler(const std::vector &stages) { - CHECK_GT(stages.size(), 0) << "No stage is provided"; - for (auto *stage : stages) { - AddStage(*stage); - } - FinishStageAdd(); - } + explicit PolyGroupScheduler(const std::vector &stages); //! Build the schedule, that is set the time schedule following each edge. - void Build() { - ScheduleGraph::node_order_t node_order; - ScheduleGraph::edge_order_t edge_order; - CHECK(!schedule_graph_.nodes().empty()); - std::tie(node_order, edge_order) = schedule_graph_.topological_order(); - for (auto *edge : edge_order) { - auto *schedule_edge = edge->as(); - auto *a_node = schedule_graph_.RetriveNode(edge->source()->As()->time_schedule.id()) - ->As(); - auto *b_node = schedule_graph_.RetriveNode(edge->sink()->As()->time_schedule.id()) - ->As(); - CHECK(a_node); - CHECK(b_node); + void Build(); - int level = schedule_edge->level; - b_node->time_schedule.OrderAfter(a_node->time_schedule, level); - } - } + private: + const std::vector stages_; /* held by value: a reference member would dangle when the constructor is given a temporary vector */ }; /** diff --git a/cinn/poly/poly_scheduler_test.cc b/cinn/poly/poly_scheduler_test.cc index 1105a35ea07724..55f86596084963 100644 --- a/cinn/poly/poly_scheduler_test.cc +++ b/cinn/poly/poly_scheduler_test.cc @@ -3,52 +3,5 @@ #include namespace cinn { -namespace poly { - -TEST(Scheduler, basic) { - isl::ctx ctx(Context::Global().isl_ctx()); - isl::set A_set(ctx, "[]->{ A[i,j]: 0{ B[i,j]: 0transform(); - - PolyGroupScheduler scheduler({A.get(), B.get()}); - scheduler.After(*A, *B, 1); - scheduler.Build(); - - auto schedule = scheduler.schedule_map(); - - EXPECT_EQ(utils::GetStreamCnt(schedule["A"]), "{ A[i, j] -> [t0 = 0, d0 = i, t1 = 0, d1 = j] }"); - EXPECT_EQ(utils::GetStreamCnt(schedule["B"]), "{ B[i, j] 
-> [t0 = 0, d0 = i, t1 = 1, d1 = j] }"); - - for (auto item : schedule) { - LOG(INFO) << item.first << " " << item.second; - } -} - -TEST(Scheduler, basic_with_transform) { - isl::ctx ctx = Context::Global().isl_ctx(); - auto A = Stage::New(isl::set(ctx, "[]->{ A[i,j]: 0{ B[i,j]: 0Split("i", 4); - LOG(INFO) << A->transform(); - B->Split(Iterator("j"), 6); - LOG(INFO) << B->transform(); - - PolyGroupScheduler scheduler({A.get(), B.get()}); - scheduler.After(*A, *B, 1); - scheduler.Build(); - auto schedule = scheduler.schedule_map(); - for (auto item : schedule) { - LOG(INFO) << item.first << " " << item.second; - } - - EXPECT_EQ(utils::GetStreamCnt(schedule["A"]), - "{ A[i_outer, i_inner, j] -> [t0 = 0, d0 = i_outer, t1 = 0, d1 = i_inner, t2 = 0, d2 = j] }"); - EXPECT_EQ(utils::GetStreamCnt(schedule["B"]), - "{ B[i, j_outer, j_inner] -> [t0 = 0, d0 = i, t1 = 1, d1 = j_outer, t2 = 0, d2 = j_inner] }"); -} - -} // namespace poly +namespace poly {} // namespace poly } // namespace cinn diff --git a/cinn/poly/schedule.cc b/cinn/poly/schedule.cc index 4fc00bf6ef57dc..6e0e152204ad21 100644 --- a/cinn/poly/schedule.cc +++ b/cinn/poly/schedule.cc @@ -138,11 +138,20 @@ std::vector GatherStagesInTensors(const std::vector &xs, bo std::map CollectScheduleMapFromGroup(const ScheduleGroup &group) { std::map map; + + std::vector stages; + LOG(INFO) << "Group to schedule as:"; for (auto &node : group.nodes) { - auto *schedule_node = node->As(); - map[schedule_node->id()] = schedule_node->time_schedule.to_isl(Context::Global().isl_ctx()); + CHECK(node->stage); /* validate the stage pointer before its first dereference below */ + LOG(INFO) << node->stage->id(); + auto *schedule_node = node->As(); + stages.push_back(node->stage); } - return map; + + PolyGroupScheduler group_scheduler(stages); + group_scheduler.Build(); + + return group_scheduler.schedule_map(); } void SchedulerBase::AddStage(const Stage &x) { diff --git a/cinn/poly/schedule.h b/cinn/poly/schedule.h index d1ff21853a4374..5d5a081f60c43b 100644 --- a/cinn/poly/schedule.h +++ 
b/cinn/poly/schedule.h @@ -191,7 +191,7 @@ struct ScheduleGraphEdge : public common::GraphEdge { */ struct ScheduleGraphNode : public common::GraphNode { TimeSchedule time_schedule; - Stage *stage; + Stage *stage{}; //! NOTE this id is not human-readable. // std::string id() const override { return std::to_string(reinterpret_cast(this)); } diff --git a/cinn/poly/stage.cc b/cinn/poly/stage.cc index 235e65e0c588a6..2542a892d07a87 100644 --- a/cinn/poly/stage.cc +++ b/cinn/poly/stage.cc @@ -215,6 +215,15 @@ void Stage::Vectorize(int level, int factor) { vectorize_info_.set(level + 1 /*inner*/, factor); } +void Stage::Vectorize(const std::string &axis, int factor) { + auto dims = GetDimNames(transformed_domain()); + auto it = std::find(dims.begin(), dims.end(), axis); + CHECK(it != dims.end()) << "No dimension called " << axis; + Vectorize(std::distance(dims.begin(), it), factor); +} + +void Stage::Vectorize(const Iterator &axis, int factor) { return Vectorize(axis.id, factor); } + std::string Stage::ith_dim_name(int level) { auto dims = GetDimNames(transformed_domain()); CHECK_LT(level, dims.size()); diff --git a/cinn/poly/stage.h b/cinn/poly/stage.h index eefc702672e655..8ad3e8fc1e8b47 100644 --- a/cinn/poly/stage.h +++ b/cinn/poly/stage.h @@ -6,8 +6,10 @@ #include #include #include +#include #include #include +#include #include #include "cinn/common/common.h" @@ -80,6 +82,8 @@ class Stage : public Object { * @param level */ void Vectorize(int level, int factor); + void Vectorize(const std::string& axis, int factor); + void Vectorize(const Iterator& axis, int factor); /** * Mark the stage compute at the level of some other stage. 
@@ -104,7 +108,13 @@ class Stage : public Object { const isl::set& domain() const { return domain_; } const isl::map& transform() const { return transform_; } - isl::set transformed_domain() const { return domain_.apply(transform_); } + isl::set transformed_domain() const { + CHECK(!domain_.is_null()); + CHECK(!transform_.is_null()); + LOG(INFO) << "domain: " << domain_; + LOG(INFO) << "transform: " << transform_; + return domain_.apply(transform_); + } std::vector compute_ats() const; diff --git a/cinn/runtime/cinn_x86_device_impl.cc b/cinn/runtime/cinn_x86_device_impl.cc index 4dee0915ec0207..8a61f6cd0356fa 100644 --- a/cinn/runtime/cinn_x86_device_impl.cc +++ b/cinn/runtime/cinn_x86_device_impl.cc @@ -1,4 +1,5 @@ #include + #include "cinn/runtime/cinn_runtime.h" int cinn_x86_malloc(void* context, cinn_buffer_t* buf) { diff --git a/cinn/utils/CMakeLists.txt b/cinn/utils/CMakeLists.txt index 798edb541fe1b6..ef29845375edad 100644 --- a/cinn/utils/CMakeLists.txt +++ b/cinn/utils/CMakeLists.txt @@ -1,4 +1,4 @@ -set(srcs string.cc functional.cc dot.cc) +set(srcs string.cc functional.cc dot.cc timer.cc) foreach(cpp ${srcs}) set(core_src diff --git a/cinn/utils/timer.cc b/cinn/utils/timer.cc new file mode 100644 index 00000000000000..557f5366c55455 --- /dev/null +++ b/cinn/utils/timer.cc @@ -0,0 +1,17 @@ +#include "cinn/utils/timer.h" + +namespace cinn { +namespace utils { + +float Timer::Stop() { + end_ = std::chrono::system_clock::now(); + auto ts = std::chrono::duration_cast(end_ - start_); + float ms = 1000.f * static_cast(ts.count()) * std::chrono::milliseconds::period::num / + std::chrono::milliseconds::period::den; + return ms; +} + +void Timer::Start() { start_ = std::chrono::system_clock::now(); } + +} // namespace utils +} // namespace cinn diff --git a/cinn/utils/timer.h b/cinn/utils/timer.h new file mode 100644 index 00000000000000..f4c98b3a297da9 --- /dev/null +++ b/cinn/utils/timer.h @@ -0,0 +1,21 @@ +#pragma once +#include +#include //NOLINT +#include + 
+namespace cinn { +namespace utils { + +class Timer { + public: + Timer() = default; + + void Start(); + float Stop(); + + private: + std::chrono::time_point start_, end_; +}; + +} // namespace utils +} // namespace cinn diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6c5a1150fa36c7..b1a87f17010690 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -18,5 +18,7 @@ cc_test(test02_matmul_case SRCS test02_matmul_case.cc ${CMAKE_BINARY_DIR}/tests/test02_matmul.cc ${CMAKE_BINARY_DIR}/tests/test02_matmul_tile.cc ${CMAKE_BINARY_DIR}/tests/test02_matmul_split.cc + ${CMAKE_BINARY_DIR}/tests/test02_matmul_block.cc + ${CMAKE_BINARY_DIR}/tests/test02_matmul_vectorize.cc DEPS core) add_dependencies(test02_matmul_case test02_matmul_main) diff --git a/tests/test02_matmul_case.cc b/tests/test02_matmul_case.cc index 07248d864f3427..2e3a83d7013659 100644 --- a/tests/test02_matmul_case.cc +++ b/tests/test02_matmul_case.cc @@ -2,26 +2,26 @@ #include #include "cinn/runtime/cinn_runtime.h" +#include "cinn/utils/timer.h" #include "tests/test02_matmul.h" +#include "tests/test02_matmul_block.h" #include "tests/test02_matmul_split.h" #include "tests/test02_matmul_tile.h" +#include "tests/test02_matmul_vectorize.h" TEST(test02, basic) { - const int M = 1000; - const int N = 400; - const int K = 500; - - auto* A = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, K}); - auto* B = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {K, N}); - auto* C = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, N}); - auto* C1 = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, N}); - auto* C2 = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, N}); + const int M = 1024; + const int N = 1024; + const int K = 1024; + + auto* A = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, K}, 32); + auto* B = 
cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {K, N}, 32); + auto* C = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, N}, 32); auto* C_target = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, N}); cinn_buffer_malloc(nullptr, A); cinn_buffer_malloc(nullptr, B); cinn_buffer_malloc(nullptr, C_target); cinn_buffer_malloc(nullptr, C); - cinn_buffer_malloc(nullptr, C1); float* Ad = reinterpret_cast(A->host_memory); float* Bd = reinterpret_cast(B->host_memory); @@ -64,15 +64,37 @@ TEST(test02, basic) { } } + cinn::utils::Timer timer; + + const int repeat = 1; + LOG(INFO) << "Testing matmul_basic"; - matmul(A, B, C); + timer.Start(); + for (int i = 0; i < repeat; i++) matmul(A, B, C); + LOG(INFO) << timer.Stop() / repeat; compare(); LOG(INFO) << "Testing matmul_tile"; - matmul_tile(A, B, C); + timer.Start(); + for (int i = 0; i < repeat; i++) matmul_tile(A, B, C); + LOG(INFO) << timer.Stop() / repeat; compare(); LOG(INFO) << "Testing matmul_split"; - matmul_split(A, B, C); + timer.Start(); + for (int i = 0; i < repeat; i++) matmul_split(A, B, C); + LOG(INFO) << timer.Stop() / repeat; + compare(); + + LOG(INFO) << "Testing matmul_block"; + timer.Start(); + for (int i = 0; i < repeat; i++) matmul_block(A, B, C); + LOG(INFO) << timer.Stop() / repeat; + compare(); + + LOG(INFO) << "Testing matmul_vectorize"; + timer.Start(); + for (int i = 0; i < repeat; i++) matmul_vectorize(A, B, C); + LOG(INFO) << timer.Stop() / repeat; compare(); } diff --git a/tests/test02_matmul_main.cc b/tests/test02_matmul_main.cc index ed67828fe3d276..59eea9807f3155 100644 --- a/tests/test02_matmul_main.cc +++ b/tests/test02_matmul_main.cc @@ -4,10 +4,11 @@ #include "cinn/optim/optimize.h" namespace cinn { +using poly::Iterator; -const int M = 1000; -const int N = 400; -const int K = 500; +const int M = 1024; +const int N = 1024; +const int K = 1024; TEST(test02_matmul, basic) { Placeholder A("A", {M, K}); @@ 
-83,12 +84,12 @@ TEST(matmul, Split) { target.bits = Target::Bit ::k32; target.os = Target::OS ::Linux; - poly::Iterator i0, i1; + Iterator i0, i1; std::tie(i0, i1) = C->stage()->Split(2, 16); - std::vector iterators({C->stage()->ith_iterator(1), - C->stage()->ith_iterator(0), - C->stage()->ith_iterator(2), - C->stage()->ith_iterator(3)}); + std::vector iterators({C->stage()->ith_iterator(1), + C->stage()->ith_iterator(0), + C->stage()->ith_iterator(2), + C->stage()->ith_iterator(3)}); C->stage()->Reorder(iterators); Module module("module3", target); @@ -104,4 +105,100 @@ TEST(matmul, Split) { compiler.Compile(module, outputs); } +TEST(matmul, Blocking) { + Placeholder A("A", {M, K}); + Placeholder B("B", {K, N}); + + Var k(K, "k"); + Buffer C_buf(Float(32)); + + int bn = 32; + + auto C_init = Compute( + {M, N}, [&](Var i, Var j) { return Expr(0.f); }, "C_init"); + C_init->Bind(C_buf); + auto C = Compute( + {M, N}, [&](Var i, Var j) { return Sum(A(i, k) * B(k, j), k); }, "C", k); + C->Bind(C_buf); + ASSERT_EQ(C->buffer_depended_tensor_names().size(), 1UL); + + Target target; + target.arch = Target::Arch ::X86; + target.bits = Target::Bit ::k32; + target.os = Target::OS ::Linux; + + // Blocking by loop tiling. 
+ { + Iterator i_outer, i_inner, j_outer, j_inner; + std::tie(i_outer, i_inner, j_outer, j_inner) = C->stage()->Tile(0, 1, bn, bn); + Iterator k_outer, k_inner; + std::tie(k_outer, k_inner) = C->stage()->Split("k", 4); + + C->stage()->Reorder({i_outer, j_outer, k_outer, k_inner, i_inner, j_inner}); + } + + // C_init->stage()->ComputeAt(C->stage(), 3); + + Module module("module_block", target); + auto funcs = Lower("matmul_block", {A, B, C, C_init}); + ASSERT_EQ(funcs.size(), 1UL); + + auto func = Optimize(funcs.front()); + module.Append(ir::LoweredFunc(func.As())); + + CodeGenCX86 compiler(target, CodeGenCX86::Feature::AVX512); + Outputs outputs; + outputs = outputs.c_header("./test02_matmul_block.h").c_source("./test02_matmul_block.cc"); + compiler.Compile(module, outputs); +} + +TEST(matmul, Vectorization) { + Placeholder A("A", {M, K}); + Placeholder B("B", {K, N}); + + Var k(K, "k"); + Buffer C_buf(Float(32)); + + int bn = 32; + + auto C_init = Compute( + {M, N}, [&](Var i, Var j) { return Expr(0.f); }, "C_init"); + C_init->Bind(C_buf); + auto C = Compute( + {M, N}, [&](Var i, Var j) { return Sum(A(i, k) * B(k, j), k); }, "C", k); + C->Bind(C_buf); + ASSERT_EQ(C->buffer_depended_tensor_names().size(), 1UL); + + Target target; + target.arch = Target::Arch ::X86; + target.bits = Target::Bit ::k32; + target.os = Target::OS ::Linux; + + // Blocking by loop tiling. 
+ { + Iterator i_outer, i_inner, j_outer, j_inner; + std::tie(i_outer, i_inner, j_outer, j_inner) = C->stage()->Tile(0, 1, bn, bn); + Iterator k_outer, k_inner; + std::tie(k_outer, k_inner) = C->stage()->Split("k", 4); + + C->stage()->Reorder({i_outer, j_outer, k_outer, k_inner, i_inner, j_inner}); + + C->stage()->Vectorize(j_inner, 8); + } + + // C_init->stage()->ComputeAt(C->stage(), 3); + + Module module("module_vectorize", target); + auto funcs = Lower("matmul_vectorize", {A, B, C, C_init}); + ASSERT_EQ(funcs.size(), 1UL); + + auto func = Optimize(funcs.front()); + module.Append(ir::LoweredFunc(func.As())); + + CodeGenCX86 compiler(target, CodeGenCX86::Feature::AVX256); + Outputs outputs; + outputs = outputs.c_header("./test02_matmul_vectorize.h").c_source("./test02_matmul_vectorize.cc"); + compiler.Compile(module, outputs); +} + } // namespace cinn