Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paged Attention alibi support #926

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ impl PagedAttention {
num_key_value_heads: Option<usize>,
sliding_window: Option<usize>,
device: &Device,
alibi_slopes: Option<Vec<f64>>,
alibi_slopes: Option<Vec<f32>>,
) -> Result<Self> {
let num_key_value_heads = num_key_value_heads.unwrap_or(num_attention_heads);
let num_queries_per_kv = num_attention_heads / num_key_value_heads;
let alibi_slopes = if let Some(alibi_slopes) = alibi_slopes {
assert_eq!(alibi_slopes.len(), head_dim);
Some(Tensor::new(alibi_slopes, device)?)
} else {
None
Expand Down
24 changes: 9 additions & 15 deletions mistralrs-core/src/paged_attention/layers/paged_attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,10 @@ use crate::{
pipeline::text_models_inputs_processor::PagedAttentionInputMetadata,
};

const _PARTITION_SIZE: usize = 512;

#[allow(dead_code)]
pub struct PagedAttention {
num_attention_heads: usize,
head_dim: usize,
num_key_value_heads: usize,
scale: f32,
sliding_window: Option<usize>,
num_queries_per_kv: usize,
n_kv_groups: usize,
alibi_slopes: Option<Tensor>,
}

Expand All @@ -28,22 +22,20 @@ impl PagedAttention {
num_key_value_heads: Option<usize>,
sliding_window: Option<usize>,
device: &Device,
alibi_slopes: Option<Vec<f64>>,
alibi_slopes: Option<Vec<f32>>,
) -> Result<Self> {
let num_key_value_heads = num_key_value_heads.unwrap_or(num_attention_heads);
let num_queries_per_kv = num_attention_heads / num_key_value_heads;
let n_kv_groups = num_attention_heads / num_key_value_heads;
let alibi_slopes = if let Some(alibi_slopes) = alibi_slopes {
assert_eq!(alibi_slopes.len(), head_dim);
Some(Tensor::new(alibi_slopes, device)?)
} else {
None
};
Ok(Self {
num_attention_heads,
head_dim,
num_key_value_heads,
scale,
sliding_window,
num_queries_per_kv,
n_kv_groups,
alibi_slopes,
})
}
Expand Down Expand Up @@ -81,6 +73,7 @@ impl PagedAttention {
let (batch_size, attention_heads, seq_len, head_size) = query.shape().dims4()?;
let (_, key_value_heads, _, _) = key.shape().dims4()?;

#[allow(clippy::cast_possible_truncation)]
let att = match attention_mask {
None => None,
Some(mask) => Some(Sdpa.run_attention(
Expand All @@ -90,11 +83,11 @@ impl PagedAttention {
Some(mask),
None,
&SdpaParams {
n_kv_groups: attention_heads / key_value_heads,
n_kv_groups: self.n_kv_groups,
use_flash_attn: false,
softcap: softcapping.map(|x| x as f32),
softmax_scale: self.scale,
sliding_window: None,
sliding_window: self.sliding_window,
},
)?),
};
Expand Down Expand Up @@ -159,6 +152,7 @@ impl PagedAttention {
value_cache.as_ref().unwrap(),
input_metadata.block_tables.as_ref().unwrap(),
input_metadata.context_lens.as_ref().unwrap(),
self.alibi_slopes.as_ref(),
input_metadata.max_context_len.unwrap(),
self.scale,
softcapping.unwrap_or(1.0f64) as f32,
Expand Down
19 changes: 19 additions & 0 deletions mistralrs-paged-attn/src/backend/paged_attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ struct PagedAttention {
value_cache: Tensor,
block_tables: Tensor,
context_lens: Tensor,
alibi_slopes: Option<Tensor>,
max_context_len: usize,
}

Expand Down Expand Up @@ -101,6 +102,19 @@ impl PagedAttention {
let cl = cl.slice(cl_l.start_offset()..);
let bt = bt.slice(bt_l.start_offset()..);

let alibi_s_ptr = if let Some(alibi_slopes) = self.alibi_slopes.as_ref() {
let (alibi_s, alibi_s_l) = alibi_slopes.storage_and_layout();
let alibi_s = match &*alibi_s {
Storage::Cuda(alibi_s) => alibi_s,
_ => candle::bail!("context_lens must be a cuda tensor"),
};
let alibi_s = alibi_s.as_cuda_slice::<f32>()?;
let alibi_s = alibi_s.slice(alibi_s_l.start_offset()..);
*alibi_s.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};

let (num_seqs, num_heads, head_size) = q_l.shape().dims3()?;
if !(head_size == 64
|| head_size == 80
Expand Down Expand Up @@ -173,6 +187,7 @@ impl PagedAttention {
q_ptr,
kc_ptr,
vc_ptr,
alibi_s_ptr,
num_kv_heads as c_int,
self.softmax_scale,
self.softcapping,
Expand Down Expand Up @@ -210,6 +225,7 @@ impl PagedAttention {
q_ptr,
kc_ptr,
vc_ptr,
alibi_s_ptr,
num_kv_heads as c_int,
self.softmax_scale,
self.softcapping,
Expand Down Expand Up @@ -270,6 +286,7 @@ impl candle::CustomOp1 for PagedAttention {
/// * `max_context_len` - Max of `context_len`
/// * `softmax_scale` - scaling factor
/// * `softcapping`- Softcapping value as in Gemma 2. Using 1.0 means do nothing.
/// * `alibi_slopes`- Optional alibi slopes, `(num_heads_q)`.
///
/// The resulting tensor has dimensions `(num_sequences, num_heads_q, head_size)`.
#[allow(clippy::too_many_arguments)]
Expand All @@ -279,6 +296,7 @@ pub fn paged_attention(
value_cache: &Tensor,
block_tables: &Tensor,
context_lens: &Tensor,
alibi_slopes: Option<&Tensor>,
max_context_len: usize,
softmax_scale: f32,
softcapping: f32,
Expand All @@ -291,6 +309,7 @@ pub fn paged_attention(
context_lens: context_lens.clone(),
max_context_len,
softcapping,
alibi_slopes: alibi_slopes.cloned(),
};
q.apply_op1(op)
}
Expand Down
2 changes: 2 additions & 0 deletions mistralrs-paged-attn/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ extern "C" {
query: *const c_void,
key_cache: *const c_void,
value_cache: *const c_void,
alibi_slopes: *const c_void,
num_kv_heads: c_int,
scale: f32,
softcapping: f32,
Expand Down Expand Up @@ -51,6 +52,7 @@ extern "C" {
query: *const c_void,
key_cache: *const c_void,
value_cache: *const c_void,
alibi_slopes: *const c_void,
num_kv_heads: c_int,
scale: f32,
softcapping: f32,
Expand Down
20 changes: 12 additions & 8 deletions mistralrs-paged-attn/src/pagedattention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ __global__ void paged_attention_v2_reduce_kernel(
block_tables, \
context_lens, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
reinterpret_cast<float*>(alibi_slopes), \
q_stride, \
kv_block_stride, \
kv_head_stride);
Expand All @@ -600,6 +600,7 @@ void paged_attention_v1_launcher(
void *query,
void *key_cache,
void *value_cache,
void* __restrict__ alibi_slopes,
int num_kv_heads,
float scale,
float softcapping,
Expand All @@ -619,8 +620,7 @@ void paged_attention_v1_launcher(
// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
// assert(head_size % thread_group_size == 0);

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = nullptr;
// NOTE: alibi_slopes is optional. It may be nullptr.

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
Expand Down Expand Up @@ -666,6 +666,7 @@ void paged_attention_v1_launcher(
query, \
key_cache, \
value_cache, \
alibi_slopes, \
num_kv_heads, \
scale, \
softcapping, \
Expand Down Expand Up @@ -702,7 +703,8 @@ extern "C" void paged_attention_v1(
void *query, // [num_seqs, num_heads, head_size]
void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
void *value_cache, // [num_blocks, num_heads, head_size, block_size]
int32_t num_kv_heads, // [num_heads]
void *alibi_slopes, // [num_heads]
int32_t num_kv_heads,
float scale,
float softcapping,
uint32_t *block_tables, // [num_seqs, max_num_blocks_per_seq]
Expand Down Expand Up @@ -740,11 +742,11 @@ extern "C" void paged_attention_v1(
reinterpret_cast<T*>(value_cache), \
num_kv_heads, \
scale, \
softcapping, \
softcapping, \
block_tables, \
context_lens, \
max_num_blocks_per_seq, \
alibi_slopes, \
reinterpret_cast<float*>(alibi_slopes), \
q_stride, \
kv_block_stride, \
kv_head_stride); \
Expand All @@ -770,6 +772,7 @@ void paged_attention_v2_launcher(
void *query,
void *key_cache,
void *value_cache,
void *alibi_slopes,
int num_kv_heads,
float scale,
float softcapping,
Expand All @@ -788,8 +791,7 @@ void paged_attention_v2_launcher(
) {
// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);

// NOTE: alibi_slopes is optional.
const float* alibi_slopes = nullptr;
// NOTE: alibi_slopes is optional. It may be nullptr.

T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out);

Expand Down Expand Up @@ -843,6 +845,7 @@ void paged_attention_v2_launcher(
query, \
key_cache, \
value_cache, \
alibi_slopes, \
num_kv_heads, \
scale, \
softcapping, \
Expand Down Expand Up @@ -882,6 +885,7 @@ extern "C" void paged_attention_v2(
void *query, // [num_seqs, num_heads, head_size]
void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
void *value_cache, // [num_blocks, num_heads, head_size, block_size]
void *alibi_slopes, // [num_heads]
int32_t num_kv_heads,
float scale,
float softcapping,
Expand Down
Loading