Refactor dispatch #38

Open · wants to merge 27 commits into base: develop
Commits (27)
ff203ce
compilable poc
tenpercent Dec 4, 2024
6e9c4d3
modify codegen and regenerate the instances to encode dispatch cases …
tenpercent Dec 4, 2024
adfcf82
encode head dimension dispatch switch as type
tenpercent Dec 4, 2024
f28928a
enable offload compression
tenpercent Dec 5, 2024
52c4042
clang-format
tenpercent Dec 5, 2024
102a5e3
fix python lints
tenpercent Dec 5, 2024
bf52044
run black
tenpercent Dec 5, 2024
a26f1c1
run isort
tenpercent Dec 5, 2024
85dc1c0
move out common struct
tenpercent Dec 5, 2024
ec724b8
add missing header
tenpercent Dec 5, 2024
2f2dca8
adjust number of compilation workers to avoid ooms in CI
tenpercent Dec 5, 2024
aa2657a
roll back run-clang-format
tenpercent Dec 5, 2024
7222125
sync wheels_build with facebookresearch
tenpercent Dec 6, 2024
47d80c9
add missing script for new wheel build
tenpercent Dec 6, 2024
de21a1d
use separate runner for wheels and ci
tenpercent Dec 6, 2024
a2ba3e7
lower max_jobs for ci build to 80
tenpercent Dec 6, 2024
69a92fe
sync .github/actions with facebookresearch
tenpercent Dec 6, 2024
3dad112
add rocm 6.2 for wheel build
tenpercent Dec 6, 2024
8cf45ca
bump pytorch to 2.5.1
tenpercent Dec 6, 2024
d576b53
do not offload compress prior to rocm 6.2
tenpercent Dec 6, 2024
b58b5de
amend version check
tenpercent Dec 6, 2024
1a25871
try limiting container memory to avoid oom in ci
tenpercent Dec 6, 2024
f16b48a
tweak memory limit for rocm_ci
tenpercent Dec 9, 2024
698506d
Merge remote-tracking branch 'origin/develop' into refactor-dispatch
tenpercent Dec 9, 2024
32ad0c0
modify MAX_JOBS
tenpercent Dec 10, 2024
1d36a2a
bump torch wheel to rocm+6.2 stable
tenpercent Dec 10, 2024
070897c
Merge remote-tracking branch 'origin/develop' into refactor-dispatch
tenpercent Dec 12, 2024
37 changes: 19 additions & 18 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h
@@ -8,43 +8,44 @@

#include "ck_tiled_fmha_batched_infer_dispatch.h"
#include "ck_tiled_fmha_batched_infer_splitkv_dispatch.h"
#include "ck_tiled_fmha_dispatch_tags.h"
#include "ck_tiled_fmha_seqlen_q_switch.h"

template <
typename ScalarType,
-bool kHasMask,
-bool kHasBias,
-bool kHasDropout,
-ck_tile::index_t MaxK>
+typename HasMask,
+typename HasBias,
+typename HasDropout,
+typename MaxHeadDimension>
void run_batched_infer_mask_bias_dropout_dispatch(
BatchedForwardParams& param,
hipStream_t stream) {
// currently split-kv implementation does not support dropout
-if constexpr (!kHasDropout) {
+if constexpr (!HasDropout::value) {
#ifndef FMHA_FWD_SPLITKV_NOT_USED
if (param.use_split_kv) {
-FMHA_FWD_SEQLEN_Q_SWITCH(param.M, MaxSeqlenQ, [&] {
+FMHA_FWD_SEQLEN_Q_SWITCH(param.M, kMaxSeqlenQ, [&] {
batched_infer_splitkv_mask_bias_dropout_dispatch<
ScalarType,
-kHasMask,
-kHasBias,
-MaxK,
-MaxSeqlenQ>::Run(param, stream);
+HasMask,
+HasBias,
+MaxHeadDimension,
+max_query_seqlen_t<kMaxSeqlenQ>>::Run(param, stream);
});
} else
#endif
batched_infer_mask_bias_dropout_dispatch<
ScalarType,
-kHasMask,
-kHasBias,
-kHasDropout,
-MaxK>::Run(param, stream);
+HasMask,
+HasBias,
+HasDropout,
+MaxHeadDimension>::Run(param, stream);
} else {
batched_infer_mask_bias_dropout_dispatch<
ScalarType,
-kHasMask,
-kHasBias,
-kHasDropout,
-MaxK>::Run(param, stream);
+HasMask,
+HasBias,
+HasDropout,
+MaxHeadDimension>::Run(param, stream);
}
};
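
Note on the pattern in this file: the dispatch entry point no longer takes raw `bool`/`ck_tile::index_t` template parameters but tag types whose payload is read through `::value`, so the dropout branch becomes `if constexpr (!HasDropout::value)`. A minimal, self-contained sketch of that before/after pattern is below; `std::bool_constant` and `std::integral_constant` stand in for the ck_tile bases, and printfs stand in for the real kernel launches.

```cpp
// Sketch only: raw non-type template parameters vs. the tag types used in
// this PR. std:: stand-ins replace ck_tile's bool_constant/integral_constant.
#include <cstdio>
#include <type_traits>

template <bool v>
struct has_dropout_t : std::bool_constant<v> {};

template <int v>
struct max_head_dimension_t : std::integral_constant<int, v> {};

// Old style: raw non-type template parameters.
template <bool kHasDropout, int MaxK>
void run_dispatch_old() {
  if constexpr (!kHasDropout) {
    std::printf("no-dropout path, MaxK=%d\n", MaxK);
  } else {
    std::printf("dropout path, MaxK=%d\n", MaxK);
  }
}

// New style: tag types, payload read through ::value.
template <typename HasDropout, typename MaxHeadDimension>
void run_dispatch_new() {
  if constexpr (!HasDropout::value) {
    std::printf("no-dropout path, head dim=%d\n", MaxHeadDimension::value);
  } else {
    std::printf("dropout path, head dim=%d\n", MaxHeadDimension::value);
  }
}

int main() {
  run_dispatch_old<false, 128>();
  run_dispatch_new<has_dropout_t<false>, max_head_dimension_t<128>>();
  run_dispatch_new<has_dropout_t<true>, max_head_dimension_t<64>>();
}
```

Behavior is unchanged; judging from the commit messages ("encode dispatch cases as types"), the gain is that each dispatch axis becomes an ordinary type that the regenerated instance files can name explicitly (see the new ck_tiled_fmha_dispatch_tags.h below).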
@@ -15,23 +15,23 @@
void batched_infer_bf16(BatchedForwardParams& param, hipStream_t stream) {
const bool has_dropout = (param.dropout_prob > 0.0f);
BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] {
-FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] {
+FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, kMaxHeadDimension, [&] {
if (param.custom_mask_type == 0 && param.window_size <= 0)
run_batched_infer_mask_bias_dropout_dispatch<
ck_tile::bf16_t,
-false,
-kHasBias,
-kHasDropout,
-MaxK>(param, stream);
+has_mask_t<false>,
+has_bias_t<kHasBias>,
+has_dropout_t<kHasDropout>,
+max_head_dimension_t<kMaxHeadDimension>>(param, stream);
else if (
param.custom_mask_type == 1 || param.custom_mask_type == 2 ||
param.window_size > 0)
run_batched_infer_mask_bias_dropout_dispatch<
ck_tile::bf16_t,
-true,
-kHasBias,
-kHasDropout,
-MaxK>(param, stream);
+has_mask_t<true>,
+has_bias_t<kHasBias>,
+has_dropout_t<kHasDropout>,
+max_head_dimension_t<kMaxHeadDimension>>(param, stream);
else
throw std::runtime_error("Invalid custom_mask_type value");
});
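
`BOOL_SWITCH_2` and `FMHA_FWD_HEADDIM_SWITCH` are defined elsewhere in the repo and are not part of this diff; judging from how they are used here, they lift runtime flags such as `param.has_attn_bias` into constexpr values that the lambda body can pass as template arguments (now wrapped into the new tag types). A generic sketch of that switch pattern, with an illustrative macro name rather than the real one:

```cpp
#include <cstdio>

// Illustrative macro only; the real BOOL_SWITCH_2 / FMHA_FWD_HEADDIM_SWITCH
// live elsewhere in the repo. The trick is textual: the lambda body is pasted
// inside each branch, where a constexpr flag of the requested name is in
// scope and can be used as a template argument.
#define SKETCH_BOOL_SWITCH(COND, CONST_NAME, ...) \
  [&] {                                           \
    if (COND) {                                   \
      constexpr bool CONST_NAME = true;           \
      return __VA_ARGS__();                       \
    } else {                                      \
      constexpr bool CONST_NAME = false;          \
      return __VA_ARGS__();                       \
    }                                             \
  }()

template <bool kHasBias>
void launch() {
  std::printf("launching with kHasBias = %d\n", int(kHasBias));
}

int main() {
  bool has_attn_bias = true; // runtime value, e.g. param.has_attn_bias
  SKETCH_BOOL_SWITCH(has_attn_bias, kHasBias, [&] { launch<kHasBias>(); });
}
```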
@@ -19,10 +19,10 @@

template <
typename ScalarType,
-bool kHasMask,
-bool kHasBias,
-bool kHasDropout,
-ck_tile::index_t MaxK>
+typename HasMask,
+typename HasBias,
+typename HasDropout,
+typename MaxHeadDimension>
struct batched_infer_mask_bias_dropout_dispatch {
template <typename FmhaTraits, typename FmhaMask>
using FmhaPipelineProblemTemp = ck_tile::BlockFmhaPipelineProblem<
@@ -37,20 +37,21 @@ struct batched_infer_mask_bias_dropout_dispatch {
typename FmhaFwdTypeConfig<ScalarType>::PDataType,
typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
typename FmhaFwdTypeConfig<ScalarType>::ODataType,
-FmhaFwdShape<MaxK>,
+FmhaFwdShape<MaxHeadDimension::value>,
false, // kIsGroupMode
FmhaMask,
FmhaTraits>;

static void Run(BatchedForwardParams& param, hipStream_t stream) {
-using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<kHasMask>;
+using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<HasMask::value>;

-using FmhaShape = FmhaFwdShape<MaxK>;
+using FmhaShape = FmhaFwdShape<MaxHeadDimension::value>;
using FmhaTilePartitioner = ck_tile::FmhaFwdTilePartitioner<FmhaShape>;
-constexpr ck_tile::index_t occupancy =
-(MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2);
+constexpr ck_tile::index_t occupancy = (MaxHeadDimension::value == 64)
+? 3
+: ((MaxHeadDimension::value == 256) ? 1 : 2);

-constexpr auto kBiasEnum = kHasBias
+constexpr auto kBiasEnum = HasBias::value
? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS
: ck_tile::BlockAttentionBiasEnum::NO_BIAS;

@@ -65,8 +66,8 @@ struct batched_infer_mask_bias_dropout_dispatch {
const bool pad_headdim = (pad_headdim_q || pad_headdim_v);

const bool use_async_pipeline =
-(!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) &&
-(MaxK <= 128));
+(!HasBias::value && (param.K % 8 == 0) && (param.Kv % 8 == 0) &&
+(MaxHeadDimension::value <= 128));

if (!use_async_pipeline) {
BOOL_SWITCH_3(
@@ -85,7 +86,7 @@ struct batched_infer_mask_bias_dropout_dispatch {
kBiasEnum,
false, // kHasBiasGrad place-holder
false, // kStoreLSE
-kHasDropout,
+HasDropout::value,
false, // kDoFp8StaticQuant place-holder
occupancy>;

@@ -117,7 +118,7 @@ struct batched_infer_mask_bias_dropout_dispatch {
kBiasEnum,
false, // kHasBiasGrad place-holder
false, // kStoreLSE
-kHasDropout,
+HasDropout::value,
false, // kDoFp8StaticQuant place-holder
occupancy>;

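
The occupancy choice in this file is unchanged in substance; it now just reads the head dimension through the tag's `::value`. A small stand-alone check of that selection logic, with `int` and `std::integral_constant` standing in for `ck_tile::index_t` and the ck_tile tag base:

```cpp
// Stand-alone check of the occupancy selection shown above (stand-in types).
#include <type_traits>

template <int v>
struct max_head_dimension_t : std::integral_constant<int, v> {};

template <typename MaxHeadDimension>
constexpr int occupancy_for() {
  return (MaxHeadDimension::value == 64)
      ? 3
      : ((MaxHeadDimension::value == 256) ? 1 : 2);
}

static_assert(occupancy_for<max_head_dimension_t<64>>() == 3);
static_assert(occupancy_for<max_head_dimension_t<128>>() == 2);
static_assert(occupancy_for<max_head_dimension_t<256>>() == 1);

int main() {}
```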
@@ -15,23 +15,23 @@
void batched_infer_fp16(BatchedForwardParams& param, hipStream_t stream) {
const bool has_dropout = (param.dropout_prob > 0.0f);
BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] {
-FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] {
+FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, kMaxHeadDimension, [&] {
if (param.custom_mask_type == 0 && param.window_size <= 0)
run_batched_infer_mask_bias_dropout_dispatch<
ck_tile::fp16_t,
-false,
-kHasBias,
-kHasDropout,
-MaxK>(param, stream);
+has_mask_t<false>,
+has_bias_t<kHasBias>,
+has_dropout_t<kHasDropout>,
+max_head_dimension_t<kMaxHeadDimension>>(param, stream);
else if (
param.custom_mask_type == 1 || param.custom_mask_type == 2 ||
param.window_size > 0)
run_batched_infer_mask_bias_dropout_dispatch<
ck_tile::fp16_t,
-true,
-kHasBias,
-kHasDropout,
-MaxK>(param, stream);
+has_mask_t<true>,
+has_bias_t<kHasBias>,
+has_dropout_t<kHasDropout>,
+max_head_dimension_t<kMaxHeadDimension>>(param, stream);
else
throw std::runtime_error("Invalid custom_mask_type value");
});
@@ -19,10 +19,10 @@

template <
typename ScalarType,
-bool kHasMask,
-bool kHasBias,
-ck_tile::index_t MaxK,
-ck_tile::index_t MaxSeqlenQ>
+typename HasMask,
+typename HasBias,
+typename MaxHeadDimension,
+typename MaxSeqlenQ>
struct batched_infer_splitkv_mask_bias_dropout_dispatch {
template <
typename FmhaFwdSplitKVTraits,
@@ -40,7 +40,9 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {
typename FmhaFwdTypeConfig<ScalarType>::PDataType,
typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
ODataType,
-typename FmhaFwdSplitKVShape<MaxK, MaxSeqlenQ>::Type,
+typename FmhaFwdSplitKVShape<
+MaxHeadDimension::value,
+MaxSeqlenQ::value>::Type,
false, // kIsGroupMode
FmhaMask,
FmhaFwdSplitKVTraits>;
@@ -54,23 +56,24 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {
typename FmhaFwdTypeConfig<ScalarType>::LSEDataType,
typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
typename FmhaFwdTypeConfig<ScalarType>::ODataType,
-MaxK, // headdim_v
+MaxHeadDimension::value, // headdim_v
kM0,
kN1,
false, // kIsGroupMode
FmhaSplitKVCombineTraits>;

static void Run(BatchedForwardParams& param, hipStream_t stream) {
{
-using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<kHasMask>;
+using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<HasMask::value>;

-using FmhaTileShape =
-typename FmhaFwdSplitKVShape<MaxK, MaxSeqlenQ>::Type;
+using FmhaTileShape = typename FmhaFwdSplitKVShape<
+MaxHeadDimension::value,
+MaxSeqlenQ::value>::Type;
using FmhaTilePartitioner =
ck_tile::FmhaFwdSplitKVTilePartitioner<FmhaTileShape>;
constexpr ck_tile::index_t occupancy = -1;

-constexpr auto kBiasEnum = kHasBias
+constexpr auto kBiasEnum = HasBias::value
? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS
: ck_tile::BlockAttentionBiasEnum::NO_BIAS;

@@ -174,8 +177,9 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {
};

if (param.num_kv_splits > 1) {
-using FmhaTileShape =
-typename FmhaFwdSplitKVShape<MaxK, MaxSeqlenQ>::Type;
+using FmhaTileShape = typename FmhaFwdSplitKVShape<
+MaxHeadDimension::value,
+MaxSeqlenQ::value>::Type;

constexpr ck_tile::index_t kM0 = FmhaTileShape::kM0 / 2;
constexpr ck_tile::index_t kN1 = FmhaTileShape::kN1 / 2;
25 changes: 25 additions & 0 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_dispatch_tags.h
@@ -0,0 +1,25 @@
/*
* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once

#include "ck_tile/core/numeric/integral_constant.hpp"

template <bool v>
struct has_mask_t : ck_tile::bool_constant<v> {};

template <bool v>
struct has_bias_t : ck_tile::bool_constant<v> {};

template <bool v>
struct has_dropout_t : ck_tile::bool_constant<v> {};

template <ck_tile::index_t v>
struct max_head_dimension_t : ck_tile::integral_constant<ck_tile::index_t, v> {
};

template <ck_tile::index_t v>
struct max_query_seqlen_t : ck_tile::integral_constant<ck_tile::index_t, v> {};
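
These tags are thin wrappers over ck_tile's `bool_constant`/`integral_constant`. One thing encoding dispatch cases as types makes possible (an inference from the commit messages, not something stated in the diff) is that a case can be named and matched directly, e.g. by partial specialization or by explicit instantiation in generated instance files. A hypothetical sketch using `std` stand-ins for the ck_tile bases:

```cpp
// Sketch: dispatch cases as ordinary types, matched by partial specialization.
// std::bool_constant / std::integral_constant replace the ck_tile bases.
#include <cstdio>
#include <type_traits>

template <bool v>
struct has_mask_t : std::bool_constant<v> {};

template <int v>
struct max_head_dimension_t : std::integral_constant<int, v> {};

// Primary template: generic path.
template <typename HasMask, typename MaxHeadDimension>
struct infer_dispatch {
  static void Run() {
    std::printf("generic path: mask=%d, head dim=%d\n",
                int(HasMask::value), MaxHeadDimension::value);
  }
};

// Partial specialization: a dedicated case, selected purely by the tag types.
template <typename HasMask>
struct infer_dispatch<HasMask, max_head_dimension_t<256>> {
  static void Run() {
    std::printf("dedicated 256-head-dim path, mask=%d\n", int(HasMask::value));
  }
};

int main() {
  infer_dispatch<has_mask_t<true>, max_head_dimension_t<128>>::Run();
  infer_dispatch<has_mask_t<false>, max_head_dimension_t<256>>::Run();
}
```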
37 changes: 19 additions & 18 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h
@@ -6,45 +6,46 @@
*/
#pragma once

#include "ck_tiled_fmha_dispatch_tags.h"
#include "ck_tiled_fmha_grouped_infer_dispatch.h"
#include "ck_tiled_fmha_grouped_infer_splitkv_dispatch.h"
#include "ck_tiled_fmha_seqlen_q_switch.h"

template <
typename ScalarType,
-bool kHasMask,
-bool kHasBias,
-bool kHasDropout,
-ck_tile::index_t MaxK>
+typename HasMask,
+typename HasBias,
+typename HasDropout,
+typename MaxHeadDimension>
void run_grouped_infer_mask_bias_dropout_dispatch(
GroupedForwardParams& param,
hipStream_t stream) {
// currently split-kv implementation does not support dropout
-if constexpr (!kHasDropout) {
+if constexpr (!HasDropout::value) {
#ifndef FMHA_FWD_SPLITKV_NOT_USED
if (param.use_split_kv) {
-FMHA_FWD_SEQLEN_Q_SWITCH(param.max_seqlen_q, MaxSeqlenQ, [&] {
+FMHA_FWD_SEQLEN_Q_SWITCH(param.max_seqlen_q, kMaxSeqlenQ, [&] {
grouped_infer_splitkv_mask_bias_dropout_dispatch<
ScalarType,
-kHasMask,
-kHasBias,
-MaxK,
-MaxSeqlenQ>::Run(param, stream);
+HasMask,
+HasBias,
+MaxHeadDimension,
+max_query_seqlen_t<kMaxSeqlenQ>>::Run(param, stream);
});
} else
#endif
grouped_infer_mask_bias_dropout_dispatch<
ScalarType,
-kHasMask,
-kHasBias,
-kHasDropout,
-MaxK>::Run(param, stream);
+HasMask,
+HasBias,
+HasDropout,
+MaxHeadDimension>::Run(param, stream);
} else {
grouped_infer_mask_bias_dropout_dispatch<
ScalarType,
-kHasMask,
-kHasBias,
-kHasDropout,
-MaxK>::Run(param, stream);
+HasMask,
+HasBias,
+HasDropout,
+MaxHeadDimension>::Run(param, stream);
}
};
@@ -15,23 +15,23 @@
void grouped_infer_bf16(GroupedForwardParams& param, hipStream_t stream) {
const bool has_dropout = (param.dropout_prob > 0.0f);
BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] {
-FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] {
+FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, kMaxHeadDimension, [&] {
if (param.custom_mask_type == 0 && param.window_size <= 0)
run_grouped_infer_mask_bias_dropout_dispatch<
ck_tile::bf16_t,
-false,
-kHasBias,
-kHasDropout,
-MaxK>(param, stream);
+has_mask_t<false>,
+has_bias_t<kHasBias>,
+has_dropout_t<kHasDropout>,
+max_head_dimension_t<kMaxHeadDimension>>(param, stream);
else if (
param.custom_mask_type == 1 || param.custom_mask_type == 2 ||
param.window_size > 0)
run_grouped_infer_mask_bias_dropout_dispatch<
ck_tile::bf16_t,
-true,
-kHasBias,
-kHasDropout,
-MaxK>(param, stream);
+has_mask_t<true>,
+has_bias_t<kHasBias>,
+has_dropout_t<kHasDropout>,
+max_head_dimension_t<kMaxHeadDimension>>(param, stream);
else
throw std::runtime_error("Invalid custom_mask_type value");
});