[CINN] Add FuseParallelMatmulPass (#63623)
* [CINN] Add FuseParallelMatmulPass

* delete CHECK_EQ

* pass test_llama_mlp_dy unittest
jiahy0825 authored Apr 19, 2024
1 parent 69d3115 commit 268d75b
Showing 6 changed files with 322 additions and 3 deletions.
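In brief: the new pass detects sibling matmul ops that share the same left-hand operand and whose weights agree on every dimension except the last, fuses them into a single matmul over the concatenated weights, and then slices the fused result back into the original outputs. Schematically (a sketch of the rewrite with illustrative shapes, not the exact printed IR):

    out1 = matmul(x, w1)   // w1: [K, N1]
    out2 = matmul(x, w2)   // w2: [K, N2]
    out3 = matmul(x, w3)   // w3: [K, N3]

becomes

    w    = concat(w1, w2, w3, axis=-1)   // [K, N1 + N2 + N3]
    out  = matmul(x, w)
    out1 = slice(out, axis=-1, start=0,       end=N1)
    out2 = slice(out, axis=-1, start=N1,      end=N1 + N2)
    out3 = slice(out, axis=-1, start=N1 + N2, end=N1 + N2 + N3)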
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
@@ -31,6 +31,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.h"
@@ -80,6 +81,7 @@ void ApplyPdToCinnPass(
const std::function<std::shared_ptr<::pir::PassManager>()>&
CreatePassManager) {
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
pass_manager->AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass());
pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass());
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
pass_manager->Run(program);
171 changes: 171 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.cc
@@ -0,0 +1,171 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builtin_dialect.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/include/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"
#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"

namespace cinn {
namespace dialect {
namespace ir {

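// Fuses "parallel" matmuls: several matmul ops that consume the same
// left-hand operand and whose weights differ only in the last dimension
// are replaced by one matmul over the concatenated weights, followed by
// slices that restore the original outputs.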
class MergeParallelMatmulPattern
: public pir::OpRewritePattern<paddle::dialect::MatmulOp> {
public:
using pir::OpRewritePattern<paddle::dialect::MatmulOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::MatmulOp matmul_op,
pir::PatternRewriter& rewriter) const override {
auto ValidMatmulTranspose = [&](pir::Operation* op) -> bool {
if (!op->dyn_cast<paddle::dialect::MatmulOp>()) {
return false;
}
bool trans_x =
op->attribute("transpose_x").dyn_cast<pir::BoolAttribute>().data();
bool trans_y =
op->attribute("transpose_y").dyn_cast<pir::BoolAttribute>().data();
return !trans_x && !trans_y;
};
if (!ValidMatmulTranspose(matmul_op)) {
return false;
}

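// Weights are concatenated along the last axis, so two weights are
// compatible iff they agree on every dimension except the last.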
auto VectorPrefixEqual = [](const std::vector<std::int64_t>& a,
const std::vector<std::int64_t>& b) {
if (a.size() != b.size()) {
return false;
}
for (std::size_t i = 0; i + 1 < a.size(); ++i) {
if (a[i] != b[i]) {
return false;
}
}
return true;
};

auto input_x = matmul_op.operand_source(0);
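// Collect every fusable user of input_x: non-transposed matmuls whose
// weights share every dimension but the last with the first candidate.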
const std::vector<pir::Operation*> merge_ops = [&]() {
std::vector<pir::Operation*> ret;
std::optional<std::vector<std::int64_t>> pre_dim;
std::vector<std::int64_t> cur_dim;
for (auto it = input_x.use_begin(); it != input_x.use_end(); ++it) {
if (!ValidMatmulTranspose(it->owner())) {
continue;
}
if (!pre_dim.has_value()) {
pre_dim = ::common::vectorize(
it->owner()
->operand_source(1)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims());
}
cur_dim = ::common::vectorize(
it->owner()
->operand_source(1)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims());
if (VectorPrefixEqual(pre_dim.value(), cur_dim)) {
ret.push_back(it->owner());
}
}
return ret;
}();
if (merge_ops.size() <= 1) {
return false;
}

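// The weight operands of the matmuls to be fused.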
const std::vector<pir::Value> combine_ins = [&]() {
std::vector<pir::Value> ret;
for (pir::Operation* op : merge_ops) {
ret.push_back(op->operand_source(1));
}
return ret;
}();
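// Prefix sums of the weights' last-dim extents; entries [i, i + 1) give
// the slice window of the i-th original output in the fused result.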
const std::vector<std::int64_t> combine_shapes = [&]() {
std::vector<std::int64_t> ret{0};
std::int64_t accumulate = 0;
for (pir::Value input : combine_ins) {
auto shape =
input.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
accumulate += shape[shape.size() - 1];
ret.push_back(accumulate);
}
return ret;
}();

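// Concatenate the weights along the last axis and issue a single matmul
// against the shared left-hand operand.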
auto combine_out = rewriter.Build<pir::CombineOp>(combine_ins).result(0);
auto concat_out =
rewriter.Build<paddle::dialect::ConcatOp>(combine_out, -1).result(0);
auto matmul_out =
rewriter.Build<paddle::dialect::MatmulOp>(input_x, concat_out)
.result(0);

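// Slice the fused output back into per-op results along the last axis,
// redirect all uses, and erase the original matmuls.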
for (size_t i = 0; i < merge_ops.size(); ++i) {
auto split_out =
rewriter
.Build<paddle::dialect::SliceOp>(
matmul_out,
std::vector<std::int64_t>{
matmul_out.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims()
.size() -
1},
std::vector<std::int64_t>{combine_shapes[i]},
std::vector<std::int64_t>{combine_shapes[i + 1]},
std::vector<std::int64_t>{},
std::vector<std::int64_t>{})
.result(0);

rewriter.ReplaceAllUsesWith(merge_ops[i]->result(0), split_out);
rewriter.EraseOp(merge_ops[i]);
}

return true;
}
};

class FuseParallelMatmulPass : public pir::PatternRewritePass {
public:
FuseParallelMatmulPass()
: pir::PatternRewritePass("fuse_parallel_matmul_pass", 1) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
ps.Add<MergeParallelMatmulPattern>(context);
return ps;
}
};

std::unique_ptr<pir::Pass> CreateFuseParallelMatmulPass() {
return std::make_unique<FuseParallelMatmulPass>();
}

} // namespace ir
} // namespace dialect
} // namespace cinn
28 changes: 28 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h
@@ -0,0 +1,28 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/include/pass/pass.h"

namespace cinn {
namespace dialect {
namespace ir {

IR_API std::unique_ptr<pir::Pass> CreateFuseParallelMatmulPass();

} // namespace ir
} // namespace dialect
} // namespace cinn
6 changes: 5 additions & 1 deletion test/cpp/pir/cinn/CMakeLists.txt
@@ -28,6 +28,9 @@ if(WITH_TESTING AND WITH_CINN)
paddle_test(test_generate_shape_util_test SRCS generate_shape_util_test.cc
DEPS cinn_op_dialect)

paddle_test(merge_parallel_matmul_pass_test SRCS
merge_parallel_matmul_pass_test.cc)

# DO NOT forget to add the test name here; otherwise it will not be
# executed in CINN CI.
set(cinn_unit_tests
@@ -40,7 +43,8 @@ if(WITH_TESTING AND WITH_CINN)
test_group_op
test_pir_build_cinn_pass
test_compilation_task
test_generate_shape_util_test)
test_generate_shape_util_test
merge_parallel_matmul_pass_test)

foreach(test_name ${cinn_unit_tests})
get_property(
111 changes: 111 additions & 0 deletions test/cpp/pir/cinn/merge_parallel_matmul_pass_test.cc
@@ -0,0 +1,111 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory>

#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builtin_dialect.h"
#include "paddle/pir/include/pass/pass_manager.h"
#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"

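// Builds three parallel matmuls over the same LHS x ([32, 32]); the
// weights share the leading dim (32) and differ only in the last dim
// (32 / 64 / 128), so the pass should fuse them into a single matmul.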
void BuildProgram(pir::Builder &builder) { // NOLINT
paddle::dialect::FullOp x =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{32, 32}, 0.5);

paddle::dialect::FullOp weight_1 =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{32, 32}, 0.5);
paddle::dialect::FullOp weight_2 =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{32, 64}, 0.5);
paddle::dialect::FullOp weight_3 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{32, 128}, 0.5);

paddle::dialect::MatmulOp matmul_op1 =
builder.Build<paddle::dialect::MatmulOp>(x.out(), weight_1.out());
paddle::dialect::MatmulOp matmul_op2 =
builder.Build<paddle::dialect::MatmulOp>(x.out(), weight_2.out());
paddle::dialect::MatmulOp matmul_op3 =
builder.Build<paddle::dialect::MatmulOp>(x.out(), weight_3.out());

builder.Build<paddle::dialect::FetchOp>(matmul_op1.out(), "x", 0);
builder.Build<paddle::dialect::FetchOp>(matmul_op2.out(), "y", 1);
builder.Build<paddle::dialect::FetchOp>(matmul_op3.out(), "z", 2);
}

TEST(Cinn, FuseMatmul) {
pir::IrContext *ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<pir::BuiltinDialect>();
pir::Program program(ctx);
pir::Builder builder = pir::Builder(ctx, program.block());
BuildProgram(builder);
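// Before the pass: 4 full + 3 matmul + 3 fetch = 10 ops.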
ASSERT_EQ((program.block()->size()), 10u);

pir::PassManager pm(ctx);
pm.AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass());
pm.EnablePassTiming();
pm.EnableIRPrinting();

ASSERT_EQ((pm.Run(&program)), true);
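// After fusion the 4 full and 3 fetch ops survive, joined by the fused
// combine/concat/matmul/slice chain plus the attribute ops the builders
// generate for the concat axis and slice bounds: 20 ops in total.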
ASSERT_EQ((program.block()->size()), 20u);
}

// [64, 32] * [16, 32, 32] => [16, 64, 32]
void BuildBatchProgram(pir::Builder &builder) { // NOLINT
paddle::dialect::FullOp x =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 32}, 0.5);

paddle::dialect::FullOp weight_1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{16, 32, 32}, 0.5);
paddle::dialect::FullOp weight_2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{16, 32, 64}, 0.5);
paddle::dialect::FullOp weight_3 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{16, 32, 128}, 0.5);

paddle::dialect::MatmulOp matmul_op1 =
builder.Build<paddle::dialect::MatmulOp>(x.out(), weight_1.out());
paddle::dialect::MatmulOp matmul_op2 =
builder.Build<paddle::dialect::MatmulOp>(x.out(), weight_2.out());
paddle::dialect::MatmulOp matmul_op3 =
builder.Build<paddle::dialect::MatmulOp>(x.out(), weight_3.out());

builder.Build<paddle::dialect::FetchOp>(matmul_op1.out(), "x", 0);
builder.Build<paddle::dialect::FetchOp>(matmul_op2.out(), "y", 1);
builder.Build<paddle::dialect::FetchOp>(matmul_op3.out(), "z", 2);
}

TEST(Cinn, FuseBatchMatmul) {
pir::IrContext *ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<pir::BuiltinDialect>();
pir::Program program(ctx);
pir::Builder builder = pir::Builder(ctx, program.block());
BuildBatchProgram(builder);
ASSERT_EQ((program.block()->size()), 10u);

pir::PassManager pm(ctx);
pm.AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass());
pm.EnablePassTiming();
pm.EnableIRPrinting();

ASSERT_EQ((pm.Run(&program)), true);
ASSERT_EQ((program.block()->size()), 20u);
}
7 changes: 5 additions & 2 deletions test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py
@@ -64,8 +64,11 @@ def prepare_data(self):
self.hidden_states.stop_gradient = False

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})
# The fused matmul splits the graph into two FusionOps:
#   FusionOp1: concat
#   FusionOp2: slice, generate_shape, etc.
utils.check_jit_kernel_number(static_fn, 2)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 2})

def eval(self, use_cinn):
paddle.seed(2024)