inference support flash_attn #64213 (Merged)
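
This PR adds FlashAttention support to Paddle Inference: the flashattn headers and libraries are packaged with the inference distribution, fused_flash_attn_pass is added to the default PIR GPU pass list, the existing DRR patterns are renamed to *WithMask, a no-mask variant is added, and the pass only applies on builds compiled with PADDLE_WITH_FLASHATTN.

A minimal usage sketch with the Paddle Inference C++ API; the config calls below come from the public API rather than this diff, and the model path is a placeholder:

```cpp
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./attention_model");  // hypothetical model directory
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/512, /*device_id=*/0);
  // On builds where the PIR pipeline runs, kPirGpuPasses (which now includes
  // fused_flash_attn_pass) is applied while the predictor is created, so
  // eligible attention subgraphs are rewritten to pd_op.flash_attn.
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}
```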

8 changes: 8 additions & 0 deletions cmake/inference_lib.cmake
@@ -182,6 +182,14 @@ function(copy_part_of_third_party TARGET DST)
    SRCS ${XXHASH_INCLUDE_DIR} ${XXHASH_LIBRARIES}
    DSTS ${dst_dir} ${dst_dir}/lib)

  if(WITH_FLASHATTN)
    set(dst_dir "${DST}/third_party/install/flashattn")
    copy(
      ${TARGET}
      SRCS ${FLASHATTN_INCLUDE_DIR} ${FLASHATTN_LIBRARIES}
      DSTS ${dst_dir} ${dst_dir}/lib)
  endif()

  if(NOT PROTOBUF_FOUND OR WIN32)
    set(dst_dir "${DST}/third_party/install/protobuf")
    copy(
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -603,6 +603,7 @@ const std::vector<std::string> kPirGpuPasses{
"conv2d_add_act_fuse_pass",
"conv2d_add_fuse_pass",
"embedding_eltwise_layernorm_fuse_pass",
"fused_flash_attn_pass",
"multihead_matmul_fuse_pass",
"fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass",
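With this entry, fused_flash_attn_pass runs as part of the default PIR GPU pipeline during inference. It can also be driven directly through a PIR pass manager; a sketch, assuming the pass exposes the usual CreateFusedFlashAttnPass() factory in its header (exact include paths vary between Paddle versions):

```cpp
#include "paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"

// Rewrites eligible attention subgraphs in `program` to pd_op.flash_attn.
void ApplyFlashAttnFusion(pir::Program *program) {
  pir::PassManager pm(pir::IrContext::Instance());
  pm.AddPass(pir::CreateFusedFlashAttnPass());
  pm.Run(program);
}
```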
168 changes: 158 additions & 10 deletions paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.cc
@@ -23,15 +23,18 @@

namespace {

class FlashAttnPatternQscale : public paddle::drr::DrrPatternBase {
// 1. scale after q
// 2. cast before and after softmax
// 3. with mask
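// (The WithMask variants additionally match an attention mask added to the
// scaled qk product before the softmax; the NoMask variant below omits it.)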
class FlashAttnPatternQscaleWithMask : public paddle::drr::DrrPatternBase {
 private:
  bool softmax_with_cast_;

 public:
  explicit FlashAttnPatternQscale(bool softmax_with_cast)
  explicit FlashAttnPatternQscaleWithMask(bool softmax_with_cast)
      : softmax_with_cast_(softmax_with_cast) {}

  std::string name() const override { return "FlashAttnPatternQscale"; }
  std::string name() const override { return "FlashAttnPatternQscaleWithMask"; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern src = ctx->SourcePattern();
@@ -167,16 +170,19 @@ class FlashAttnPatternQscale : public paddle::drr::DrrPatternBase {

// 1. scale after matmul
// 2. cast before and after softmax
class FlashAttnPatternOutscale : public paddle::drr::DrrPatternBase {
// 3. with mask
class FlashAttnPatternOutscaleWithMask : public paddle::drr::DrrPatternBase {
 private:
  bool softmax_with_cast_;

 public:
  explicit FlashAttnPatternOutscale(bool softmax_with_cast)
  explicit FlashAttnPatternOutscaleWithMask(bool softmax_with_cast)
      : softmax_with_cast_(softmax_with_cast) {}

 public:
  std::string name() const override { return "FlashAttnPatternOutscale"; }
  std::string name() const override {
    return "FlashAttnPatternOutscaleWithMask";
  }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern src = ctx->SourcePattern();
@@ -308,6 +314,136 @@ class FlashAttnPatternOutscale : public paddle::drr::DrrPatternBase {
  }
};

// 1. scale after matmul
// 2. cast before and after softmax
// 3. no mask
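// Matched subgraph: q, k, v are transposed from [b, s, head, head_dim] to
// [b, head, s, head_dim]; qk = matmul(q_t, k_t, transpose_y=true) is scaled,
// run through softmax (optionally wrapped in casts), multiplied with v_t, and
// transposed back. The result pattern replaces all of this with a single
// pd_op.flash_attn call whose optional inputs are filled with
// InputNoneTensor(), since this variant carries no attention mask.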
class FlashAttnPatternOutscaleNoMask : public paddle::drr::DrrPatternBase {
 private:
  bool softmax_with_cast_;

 public:
  explicit FlashAttnPatternOutscaleNoMask(bool softmax_with_cast)
      : softmax_with_cast_(softmax_with_cast) {}

 public:
  std::string name() const override { return "FlashAttnPatternOutscaleNoMask"; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern src = ctx->SourcePattern();
    // check the transpose,
    // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale
    const auto &transpose_q = src.Op("pd_op.transpose");
    src.Tensor("q_transpose_out") = transpose_q(src.Tensor("q"));
    // k[b, s, head, head_dim] -> transpose -> k[b, head, s, head_dim]
    const auto &transpose_k = src.Op("pd_op.transpose");
    src.Tensor("k_transpose_out") = transpose_k(src.Tensor("k"));
    // v[b, s, head, head_dim] -> transpose -> v[b, head, s, head_dim]
    const auto &transpose_v = src.Op("pd_op.transpose");
    src.Tensor("v_transpose_out") = transpose_v(src.Tensor("v"));
    // qk
    const auto &qk_matmul =
        src.Op("pd_op.matmul",
               {{"transpose_x", src.Attr("matmul_qk_transpose_x")},
                {"transpose_y", src.Attr("matmul_qk_transpose_y")}});
    src.Tensor("qk_out") =
        qk_matmul(src.Tensor("q_transpose_out"), src.Tensor("k_transpose_out"));
    const auto &scale_out = src.Op("pd_op.scale");
    const auto &full_scale =
        src.Op("pd_op.full", {{"value", src.Attr("scale_out_value")}});
    src.Tensor("qk_scale_out") = scale_out(src.Tensor("qk_out"), full_scale());

    if (softmax_with_cast_) {
      // cast + softmax + cast
      const auto &softmax_cast1 = src.Op("pd_op.cast");
      src.Tensor("softmax_cast1_out") =
          softmax_cast1(src.Tensor("qk_scale_out"));
      const auto &softmax =
          src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
      src.Tensor("softmax_cast2_in") = softmax(src.Tensor("softmax_cast1_out"));
      const auto &softmax_cast2 = src.Op("pd_op.cast");
      src.Tensor("softmax_out") = softmax_cast2(src.Tensor("softmax_cast2_in"));
    } else {
      // softmax
      const auto &softmax =
          src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
      src.Tensor("softmax_out") = softmax(src.Tensor("qk_scale_out"));
    }

    // o
    const auto &context_matmul =
        src.Op("pd_op.matmul",
               {{"transpose_x", src.Attr("context_matmul_transpose_x")},
                {"transpose_y", src.Attr("context_matmul_transpose_y")}});
    src.Tensor("context_matmul_out") = context_matmul(
        src.Tensor("softmax_out"), src.Tensor("v_transpose_out"));
    const auto &o_transpose = src.Op("pd_op.transpose");
    src.Tensor("out") = o_transpose(src.Tensor("context_matmul_out"));

    // Constraints
    src.AddConstraint([](const paddle::drr::MatchContext &match_ctx) -> bool {
      auto q_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("q"));
      if (!q_dtype.isa<pir::Float16Type>() &&
          !q_dtype.isa<pir::BFloat16Type>()) {
        return false;
      }
      // softmax
      const auto &softmax_axis = match_ctx.Attr<int>("softmax_axis");
      if (softmax_axis != -1 && softmax_axis != 3) return false;
      // matmul transpose
      bool matmul_qk_transpose_x =
          match_ctx.Attr<bool>("matmul_qk_transpose_x");
      bool matmul_qk_transpose_y =
          match_ctx.Attr<bool>("matmul_qk_transpose_y");
      if (matmul_qk_transpose_x || !matmul_qk_transpose_y) return false;

      bool matmul_o_transpose_x =
          match_ctx.Attr<bool>("context_matmul_transpose_x");
      bool matmul_o_transpose_y =
          match_ctx.Attr<bool>("context_matmul_transpose_y");
      if (matmul_o_transpose_x || matmul_o_transpose_y) return false;
      // tensor shape
      auto q_transpose_out =
          pir::GetShapeFromValue(match_ctx.Tensor("q_transpose_out"));
      auto k_transpose_out =
          pir::GetShapeFromValue(match_ctx.Tensor("k_transpose_out"));
      auto v_transpose_out =
          pir::GetShapeFromValue(match_ctx.Tensor("v_transpose_out"));
      if (q_transpose_out.size() != 4 || k_transpose_out.size() != 4 ||
          v_transpose_out.size() != 4 ||
          !(q_transpose_out.at(0) == k_transpose_out.at(0) &&
            k_transpose_out.at(0) == v_transpose_out.at(0)) ||
          !(q_transpose_out.at(1) == k_transpose_out.at(1) &&
            k_transpose_out.at(1) == v_transpose_out.at(1)) ||
          !(q_transpose_out.at(3) == k_transpose_out.at(3) &&
            k_transpose_out.at(3) == v_transpose_out.at(3))) {
        return false;
      }

      return true;
    });

    //
    // Result Pattern.
    //
    paddle::drr::ResultPattern res = src.ResultPattern();
    const auto &flash_attn = res.Op("pd_op.flash_attn",
                                    {{{"dropout", res.Float32Attr(0.0)},
                                      {"causal", res.BoolAttr(false)},
                                      {"return_softmax", res.BoolAttr(false)},
                                      {"is_test", res.BoolAttr(true)},
                                      {"rng_name", res.StrAttr("")}}});
    flash_attn({&res.Tensor("q"),
                &res.Tensor("k"),
                &res.Tensor("v"),
                &res.InputNoneTensor(),
                &res.InputNoneTensor()},
               {&res.Tensor("out"),
                &res.Tensor("softmax"),
                &res.Tensor("softmax_lse"),
                &res.Tensor("seed_offset")});
  }
};

// slice qkv
class TransposeSliceFlashAttnPattern : public paddle::drr::DrrPatternBase {
 public:
@@ -466,13 +602,25 @@ class FusedFlashAttnPass : public pir::PatternRewritePass {

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
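    // The boolean argument selects whether the pattern matches a softmax
    // wrapped in cast ops (true) or a bare softmax (false); each variant is
    // registered both ways.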
    ps.Add(paddle::drr::Create<FlashAttnPatternQscale>(context, true));
    ps.Add(paddle::drr::Create<FlashAttnPatternQscale>(context, false));
    ps.Add(paddle::drr::Create<FlashAttnPatternOutscale>(context, true));
    ps.Add(paddle::drr::Create<FlashAttnPatternOutscale>(context, false));
    ps.Add(paddle::drr::Create<FlashAttnPatternQscaleWithMask>(context, true));
    ps.Add(paddle::drr::Create<FlashAttnPatternQscaleWithMask>(context, false));
    ps.Add(
        paddle::drr::Create<FlashAttnPatternOutscaleWithMask>(context, true));
    ps.Add(
        paddle::drr::Create<FlashAttnPatternOutscaleWithMask>(context, false));
    ps.Add(paddle::drr::Create<FlashAttnPatternOutscaleNoMask>(context, true));
    ps.Add(paddle::drr::Create<FlashAttnPatternOutscaleNoMask>(context, false));
    ps.Add(paddle::drr::Create<TransposeSliceFlashAttnPattern>(context));
    return ps;
  }

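  // The pass only matches anything on builds compiled with flash-attn
  // support; otherwise CanApplyOn() returns false and the pass is skipped.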
  bool CanApplyOn(pir::Operation *op) const override {
#ifdef PADDLE_WITH_FLASHATTN
    return op->num_regions() > 0;
#else
    return false;
#endif
  }
};

} // namespace