Adds small-batch kernels (#1126)
yjk21 authored Jul 17, 2021
1 parent c1378e6 commit ed71996
Showing 15 changed files with 1,584 additions and 200 deletions.
217 changes: 172 additions & 45 deletions apex/contrib/csrc/fmha/fmha_api.cpp
@@ -30,28 +30,6 @@

#include "fmha.h"

void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);

void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);

void set_params(Fused_multihead_attention_fprop_params &params,
// sizes
const size_t b,
@@ -61,7 +39,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
// device pointers
void *qkv_packed_d,
void *cu_seqlens_d,
void *seqlens_d,
void *o_packed_d,
void *s_d,
float p_dropout) {
@@ -79,7 +56,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);

params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
params.seqlens = static_cast<int *>(seqlens_d);

// S = softmax(P)
params.s_ptr = s_d;
@@ -107,13 +83,9 @@ void set_params(Fused_multihead_attention_fprop_params &params,
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}

constexpr uint32_t NUM_HEADS_DIM = 2;
constexpr uint32_t THREE_DIM = 1;

std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x 3 x num_heads x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout,
const int max_seq_len,
const bool is_training,
@@ -149,28 +121,24 @@ mha_fwd(const at::Tensor &qkv, // total x 3 x num_heads x head_size, total := \

TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);

TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())

TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(seqlens.is_contiguous())

TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);

const auto sizes = qkv.sizes();

TORCH_CHECK(sizes[THREE_DIM] == 3);

const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size);
const int total = sizes[0];
const int num_heads = sizes[NUM_HEADS_DIM];
const int head_size = sizes[3];
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
@@ -191,7 +159,6 @@ mha_fwd(const at::Tensor &qkv, // total x 3 x num_heads x head_size, total := \
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
@@ -217,7 +184,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads x head_size
const at::Tensor &qkv, // total x 3 x num_heads x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel
) {
@@ -247,27 +213,23 @@ mha_bwd(const at::Tensor &dout, // total x num_heads x head_size
TORCH_CHECK(dout.dtype() == torch::kFloat16);
TORCH_CHECK(softmax.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);

TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());

TORCH_CHECK(qkv.is_contiguous());
TORCH_CHECK(cu_seqlens.is_contiguous());
TORCH_CHECK(seqlens.is_contiguous());

TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);

const auto sizes = qkv.sizes();

TORCH_CHECK(sizes[THREE_DIM] == 3);

const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size);
const int num_heads = sizes[NUM_HEADS_DIM];
const int head_size = sizes[3];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);

@@ -282,12 +244,11 @@ mha_bwd(const at::Tensor &dout, // total x num_heads x head_size
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
seqlens.data_ptr(),
dout.data_ptr(), // we set o_ptr to dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);

// we're re-using these scales scales
// we're re-using these scales
Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
@@ -298,8 +259,174 @@ mha_bwd(const at::Tensor &dout, // total x num_heads x head_size
return { dqkv, softmax };
}

std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x 3 x num_heads x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
c10::optional<at::Generator> gen_) {
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80_nl;
TORCH_CHECK(max_seq_len == seq_len);

constexpr int warps_m = 1;
constexpr int warps_n = 4; // this leads to an upper bound
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
// static_assert( mmas_m == 32 );
// static_assert( mmas_n == 4 );
const int elts_per_thread = 8 * mmas_m * mmas_n;
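// e.g. with seq_len = 512, warps_m = 1, warps_n = 4: mmas_m = 32, mmas_n = 8, so elts_per_thread = 8 * 32 * 8 = 2048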

auto stream = at::cuda::getCurrentCUDAStream().stream();

TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())

TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())

TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);

const auto sizes = qkv.sizes();

TORCH_CHECK(sizes[THREE_DIM] == 3);

const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();

auto ctx = torch::empty({ total, num_heads, head_size }, opts);

auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);

auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());

Fused_multihead_attention_fprop_params params;

set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);

// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;

if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
int num_chunks = 3;
if(batch_size == 3) {
num_chunks = 2;
}

launch(params, is_training, num_chunks, stream);

return { ctx, s };
}

std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads x head_size
const at::Tensor &qkv, // total x 3 x num_heads x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel
) {

auto stream = at::cuda::getCurrentCUDAStream().stream();

TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())

TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())

TORCH_CHECK(cu_seqlens.dim() == 1);

TORCH_CHECK(qkv.dim() == 4);

const auto sizes = qkv.sizes();

TORCH_CHECK(sizes[THREE_DIM] == 3);

const int batch_size = cu_seqlens.numel() - 1;

const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);

int seq_len = 512;
auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;

auto opts = qkv.options();

auto dqkv = torch::empty_like(qkv);

int num_chunks = 2;
if( batch_size == 1 ) {
num_chunks = 4;
}else if( batch_size == 2 ) {
num_chunks = 3;
}
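// smaller batches get more chunks; each chunk writes a partial dK/dV slice into dkv, which is reduced below (split-K)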
auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);

Fused_multihead_attention_fprop_params params;

set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
dout.data_ptr(), // o_ptr = dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);

params.dkv_ptr = dkv.data_ptr();

Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
params.dqkv_ptr = dqkv.data_ptr();

launch(params, num_chunks, stream);

//SPLIT-K reduction of num_chunks dK, dV parts

// The equivalent of the following Pytorch code:
// using namespace torch::indexing;
// at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
// torch::sum_out(view_out, dkv, 1);

const int hidden_size = num_heads * head_size;
fmha_run_noloop_reduce(
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);

return { dqkv, softmax, dkv };
}
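
For readers cross-checking the split-K reduction, the commented-out PyTorch equivalent above can be exercised on its own. The following is a minimal sketch with made-up sizes; the (total, 3, num_heads, head_size) packing and head_size == 64 follow the checks in this file, everything else is illustrative.

import torch

# Illustrative sizes only; the real values come from the kernel launch above.
total, num_chunks, num_heads, head_size = 1024, 3, 16, 64

# dqkv packs [dQ, dK, dV] along dim 1; dkv holds num_chunks partial [dK, dV] results.
dqkv = torch.zeros(total, 3, num_heads, head_size)
dkv = torch.randn(total, num_chunks, 2, num_heads, head_size)

# Same as the commented-out C++:
#   at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
#   torch::sum_out(view_out, dkv, 1);
view_out = dqkv[:, 1:]               # the dK/dV slots: (total, 2, num_heads, head_size)
torch.sum(dkv, dim=1, out=view_out)  # reduce the chunk dimension in place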

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
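
For orientation, here is a rough Python-side sketch of driving the small-batch bindings directly. The module name fmhalib and the concrete sizes are assumptions, not taken from this file; the (total, 3, num_heads, head_size) QKV packing follows the dimension checks above, and in apex these entry points are normally reached through the Python wrapper in apex/contrib/fmha rather than called by hand.

import torch
import fmhalib  # assumed name of the extension built from fmha_api.cpp

# Small batch: 2 sequences of length 512, packed back to back (no padding).
# The nl kernels require max_seq_len == 512, fp16 inputs and an sm80 (Ampere) GPU.
batch_size, seq_len, num_heads, head_size = 2, 512, 16, 64
seqlens = torch.full((batch_size,), seq_len, dtype=torch.int32, device='cuda')
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device='cuda')
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
total = batch_size * seq_len

qkv = torch.randn(total, 3, num_heads, head_size, dtype=torch.float16, device='cuda')
p_dropout, is_training = 0.1, True

# Forward: returns the attention output and the softmax/dropout tensor S (b x h x s x s).
ctx, S = fmhalib.fwd_nl(qkv, cu_seqlens, p_dropout, seq_len, is_training, None)

# Backward: S is overwritten with dP; the per-chunk partial dK/dV tensor is also returned.
dout = torch.randn_like(ctx)
dqkv, _, dkv_chunks = fmhalib.bwd_nl(dout, qkv, S, cu_seqlens, p_dropout, seq_len)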