From 268d75b69610a310f2cd36947e13af9451f9607a Mon Sep 17 00:00:00 2001 From: HongyuJia <jiahongyu@baidu.com> Date: Fri, 19 Apr 2024 15:05:12 +0800 Subject: [PATCH] [CINN] Add FuseParallelMatmulPass (#63623) * [CINN] Add FuseParallelMatmulPass * delete CHECK_EQ * pass test_llama_mlp_dy unittest --- .../operator/transforms/add_cinn_pass.cc | 2 + .../transforms/fuse_parallel_matmul_pass.cc | 171 ++++++++++++++++++ .../transforms/fuse_parallel_matmul_pass.h | 28 +++ test/cpp/pir/cinn/CMakeLists.txt | 6 +- .../cinn/merge_parallel_matmul_pass_test.cc | 111 ++++++++++++ .../ir/pir/cinn/symbolic/test_llama_mlp_dy.py | 7 +- 6 files changed, 322 insertions(+), 3 deletions(-) create mode 100644 paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.cc create mode 100644 paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h create mode 100644 test/cpp/pir/cinn/merge_parallel_matmul_pass_test.cc diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 7a32f197d2d027..d695be6a4f777f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -31,6 +31,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.h" @@ -80,6 +81,7 @@ void ApplyPdToCinnPass( const std::function<std::shared_ptr<pir::PassManager>()>& CreatePassManager) { std::shared_ptr<pir::PassManager>
pass_manager = CreatePassManager(); + pass_manager->AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass()); pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass()); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); pass_manager->Run(program); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.cc new file mode 100644 index 00000000000000..abeffecd76b974 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.cc @@ -0,0 +1,171 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/include/core/builtin_dialect.h" +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/pir/include/pattern_rewrite/pattern_applicator.h" +#include "paddle/pir/include/pattern_rewrite/pattern_match.h" +#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class MergeParallelMatmulPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(paddle::dialect::MatmulOp matmul_op, + pir::PatternRewriter& rewriter) const override { + auto ValidMatmulTranspose = [&](pir::Operation* op) -> bool { + if (!op->dyn_cast()) { + return false; + } + bool trans_x = + op->attribute("transpose_x").dyn_cast().data(); + bool trans_y = + op->attribute("transpose_y").dyn_cast().data(); + return !trans_x && !trans_y; + }; + if (!ValidMatmulTranspose(matmul_op)) { + return false; + } + + auto VectorPrefixEqual = [](const std::vector& a, + const std::vector& b) { + if (a.size() != b.size()) { + return false; + } + for (int i = 0; i < a.size() - 1; ++i) { + if (a[i] != b[i]) { + return false; + } + } + return true; + }; + + auto input_x = matmul_op.operand_source(0); + const std::vector merge_ops = [&]() { + std::vector ret; + std::optional> pre_dim; + std::vector cur_dim; + for (auto it = input_x.use_begin(); it != input_x.use_end(); ++it) { + if (!ValidMatmulTranspose(it->owner())) { 
+ continue; + } + if (!pre_dim.has_value()) { + pre_dim = ::common::vectorize( + it->owner() + ->operand_source(1) + .type() + .dyn_cast() + .dims()); + } + cur_dim = ::common::vectorize( + it->owner() + ->operand_source(1) + .type() + .dyn_cast() + .dims()); + if (VectorPrefixEqual(pre_dim.value(), cur_dim)) { + ret.push_back(it->owner()); + } + } + return ret; + }(); + if (merge_ops.size() <= 1) { + return false; + } + + const std::vector combine_ins = [&]() { + std::vector ret; + for (pir::Operation* op : merge_ops) { + ret.push_back(op->operand_source(1)); + } + return ret; + }(); + const std::vector combine_shapes = [&]() { + std::vector ret{0}; + std::int64_t accumulate = 0; + for (pir::Value input : combine_ins) { + auto shape = + input.type().dyn_cast().dims(); + accumulate += shape[shape.size() - 1]; + ret.push_back(accumulate); + } + return ret; + }(); + + auto combine_out = rewriter.Build(combine_ins).result(0); + auto concat_out = + rewriter.Build(combine_out, -1).result(0); + auto matmul_out = + rewriter.Build(input_x, concat_out) + .result(0); + + for (size_t i = 0; i < merge_ops.size(); ++i) { + auto split_out = + rewriter + .Build( + matmul_out, + std::vector{ + matmul_out.type() + .dyn_cast() + .dims() + .size() - + 1}, + std::vector{combine_shapes[i]}, + std::vector{combine_shapes[i + 1]}, + std::vector{}, + std::vector{}) + .result(0); + + rewriter.ReplaceAllUsesWith(merge_ops[i]->result(0), split_out); + rewriter.EraseOp(merge_ops[i]); + } + + return true; + } +}; + +class FuseParallelMatmulPass : public pir::PatternRewritePass { + public: + FuseParallelMatmulPass() + : pir::PatternRewritePass("fuse_parallel_matmul_pass", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context); + return ps; + } +}; + +std::unique_ptr CreateFuseParallelMatmulPass() { + return std::make_unique(); +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git 
a/paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h new file mode 100644 index 00000000000000..319bb9b3fa3456 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h @@ -0,0 +1,28 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <memory> +#include "paddle/pir/include/pass/pass.h" + +namespace cinn { +namespace dialect { +namespace ir { + +IR_API std::unique_ptr<pir::Pass> CreateFuseParallelMatmulPass(); + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/test/cpp/pir/cinn/CMakeLists.txt b/test/cpp/pir/cinn/CMakeLists.txt index bb68da48a82454..017b41b12078e6 100644 --- a/test/cpp/pir/cinn/CMakeLists.txt +++ b/test/cpp/pir/cinn/CMakeLists.txt @@ -28,6 +28,9 @@ if(WITH_TESTING AND WITH_CINN) paddle_test(test_generate_shape_util_test SRCS generate_shape_util_test.cc DEPS cinn_op_dialect) + paddle_test(merge_parallel_matmul_pass_test SRCS + merge_parallel_matmul_pass_test.cc) + # DO NOT forget add test name here, otherwise it will not be executed in # CINN CI.
set(cinn_unit_tests @@ -40,7 +43,8 @@ if(WITH_TESTING AND WITH_CINN) test_group_op test_pir_build_cinn_pass test_compilation_task - test_generate_shape_util_test) + test_generate_shape_util_test + merge_parallel_matmul_pass_test) foreach(test_name ${cinn_unit_tests}) get_property( diff --git a/test/cpp/pir/cinn/merge_parallel_matmul_pass_test.cc b/test/cpp/pir/cinn/merge_parallel_matmul_pass_test.cc new file mode 100644 index 00000000000000..6ae6c801ee6649 --- /dev/null +++ b/test/cpp/pir/cinn/merge_parallel_matmul_pass_test.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/include/core/builtin_dialect.h" +#include "paddle/pir/include/pass/pass_manager.h" +#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h" + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp x = + builder.Build(std::vector{32, 32}, 0.5); + + paddle::dialect::FullOp weight_1 = + builder.Build(std::vector{32, 32}, 0.5); + paddle::dialect::FullOp weight_2 = + builder.Build(std::vector{32, 64}, 0.5); + paddle::dialect::FullOp weight_3 = builder.Build( + std::vector{32, 128}, 0.5); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(x.out(), weight_1.out()); + paddle::dialect::MatmulOp matmul_op2 = + builder.Build(x.out(), weight_2.out()); + paddle::dialect::MatmulOp matmul_op3 = + builder.Build(x.out(), weight_3.out()); + + builder.Build(matmul_op1.out(), "x", 0); + builder.Build(matmul_op2.out(), "y", 1); + builder.Build(matmul_op3.out(), "z", 1); +} + +TEST(Cinn, FuseMatmul) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + ASSERT_EQ((program.block()->size()), 10u); + + pir::PassManager pm(ctx); + pm.AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass()); + pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + ASSERT_EQ((pm.Run(&program)), true); + ASSERT_EQ((program.block()->size()), 20u); +} + +// [64, 32] * [16, 32, 32] => [16, 64, 32] +void BuildBatchProgram(pir::Builder &builder) { // NOLINT + 
paddle::dialect::FullOp x = + builder.Build(std::vector{64, 32}, 0.5); + + paddle::dialect::FullOp weight_1 = builder.Build( + std::vector{16, 32, 32}, 0.5); + paddle::dialect::FullOp weight_2 = builder.Build( + std::vector{16, 32, 64}, 0.5); + paddle::dialect::FullOp weight_3 = builder.Build( + std::vector{16, 32, 128}, 0.5); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build(x.out(), weight_1.out()); + paddle::dialect::MatmulOp matmul_op2 = + builder.Build(x.out(), weight_2.out()); + paddle::dialect::MatmulOp matmul_op3 = + builder.Build(x.out(), weight_3.out()); + + builder.Build(matmul_op1.out(), "x", 0); + builder.Build(matmul_op2.out(), "y", 1); + builder.Build(matmul_op3.out(), "z", 1); +} + +TEST(Cinn, FuseBatchMatmul) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildBatchProgram(builder); + ASSERT_EQ((program.block()->size()), 10u); + + pir::PassManager pm(ctx); + pm.AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass()); + pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + ASSERT_EQ((pm.Run(&program)), true); + ASSERT_EQ((program.block()->size()), 20u); +} diff --git a/test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py b/test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py index 96cbbd80767023..6382ed53d6d48d 100644 --- a/test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py +++ b/test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py @@ -64,8 +64,11 @@ def prepare_data(self): self.hidden_states.stop_gradient = False def check_jit_kernel_info(self, static_fn): - utils.check_jit_kernel_number(static_fn, 1) - utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + # FusionOp split by matmul: + # FusionOp1: concat + # FusionOp2: slice, generate_shape, etc. 
+ utils.check_jit_kernel_number(static_fn, 2) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 2}) def eval(self, use_cinn): paddle.seed(2024)