From 92cfd773aa147208c116fbcaeb05147f57ed3e04 Mon Sep 17 00:00:00 2001
From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com>
Date: Sun, 16 Feb 2025 13:24:14 -0500
Subject: [PATCH] FlashAttention V2/V3 metadata with support for device
 location (#1148)

* Flash attn metadata with support for device location

* Actually cast
---
 mistralrs-core/src/attention.rs               |  6 +++
 .../src/pipeline/inputs_processor.rs          | 41 +++++++++++--------
 mistralrs-core/src/pipeline/normal.rs         |  4 +-
 3 files changed, 32 insertions(+), 19 deletions(-)

diff --git a/mistralrs-core/src/attention.rs b/mistralrs-core/src/attention.rs
index 5011231b3c..c995036d22 100644
--- a/mistralrs-core/src/attention.rs
+++ b/mistralrs-core/src/attention.rs
@@ -40,6 +40,9 @@ fn flash_attn(
     let window_size_left = sdpa_params.sliding_window;
     let window_size_right = if causal { Some(0) } else { None };
 
+    let cumulative_seqlens_q = &cumulative_seqlens_q[&q.device().location()];
+    let cumulative_seqlens_k = &cumulative_seqlens_k[&q.device().location()];
+
     candle_flash_attn::flash_attn_varlen_windowed_softcap(
         &q,
         &k,
@@ -94,6 +97,9 @@ fn flash_attn(
     let window_size_left = sdpa_params.sliding_window;
     let window_size_right = if causal { Some(0) } else { None };
 
+    let cumulative_seqlens_q = &cumulative_seqlens_q[&q.device().location()];
+    let cumulative_seqlens_k = &cumulative_seqlens_k[&q.device().location()];
+
     candle_flash_attn_v3::flash_attn_varlen_windowed(
         &q,
         &k,
diff --git a/mistralrs-core/src/pipeline/inputs_processor.rs b/mistralrs-core/src/pipeline/inputs_processor.rs
index e2eb5c46a1..eb6ed1169a 100644
--- a/mistralrs-core/src/pipeline/inputs_processor.rs
+++ b/mistralrs-core/src/pipeline/inputs_processor.rs
@@ -120,19 +120,8 @@ pub mod text_models_inputs_processor {
     pub struct FlashParams {
         pub max_q: u32,
         pub max_k: u32,
-        pub cumulative_seqlens_q: Tensor,
-        pub cumulative_seqlens_k: Tensor,
-    }
-
-    impl FlashParams {
-        pub fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
-            Ok(Self {
-                max_k: self.max_k,
-                max_q: self.max_q,
-                cumulative_seqlens_k: self.cumulative_seqlens_k.to_device(device)?,
-                cumulative_seqlens_q: self.cumulative_seqlens_q.to_device(device)?,
-            })
-        }
+        pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
+        pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
     }
 
     pub struct InputMetadata {
@@ -268,6 +257,15 @@ pub mod text_models_inputs_processor {
             .cumsum(0)?
             .to_dtype(DType::U32)?;
 
+        let mut seqlens_q_map = HashMap::new();
+        let mut seqlens_k_map = HashMap::new();
+
+        let devices = mapper.unwrap().get_unique_devices();
+        for device in devices {
+            seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
+            seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
+        }
+
         let input = Tensor::cat(&seqs_tensors, 0).unwrap();
         // Only use matmul via f16 if prompt and seqlen > 512
         if input.dim(1)? > VIA_F16_TOK_THRESHOLD {
@@ -345,8 +343,8 @@ pub mod text_models_inputs_processor {
             flash_meta: FlashParams {
                 max_k,
                 max_q,
-                cumulative_seqlens_k: seqlens_k,
-                cumulative_seqlens_q: seqlens_q,
+                cumulative_seqlens_k: seqlens_k_map,
+                cumulative_seqlens_q: seqlens_q_map,
             },
         })
     }
@@ -438,6 +436,15 @@ pub mod text_models_inputs_processor {
             .cumsum(0)?
             .to_dtype(DType::U32)?;
 
+        let mut seqlens_q_map = HashMap::new();
+        let mut seqlens_k_map = HashMap::new();
+
+        let devices = mapper.unwrap().get_unique_devices();
+        for device in devices {
+            seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
+            seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
+        }
+
         set_use_matmul_via_f16(false);
 
         let paged_attn_meta = if paged_attn_metadata.is_some() {
@@ -502,8 +509,8 @@ pub mod text_models_inputs_processor {
             flash_meta: FlashParams {
                 max_k,
                 max_q,
-                cumulative_seqlens_k: seqlens_k,
-                cumulative_seqlens_q: seqlens_q,
+                cumulative_seqlens_k: seqlens_k_map,
+                cumulative_seqlens_q: seqlens_q_map,
             },
         })
     }
diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs
index ac4ace111f..df994dad87 100644
--- a/mistralrs-core/src/pipeline/normal.rs
+++ b/mistralrs-core/src/pipeline/normal.rs
@@ -860,7 +860,7 @@ impl Loader for NormalLoader {
                         inputs.context_lens.clone(),
                         inputs.position_ids.clone(),
                         None,
-                        &inputs.flash_meta.to_device(model.device())?,
+                        &inputs.flash_meta.clone(),
                     )
                 })
                 .collect::<candle_core::Result<Vec<_>>>()?;
@@ -1314,7 +1314,7 @@ impl Pipeline for NormalPipeline {
             let seqlen_offsets = seqlen_offsets.clone();
             let context_lens = context_lens.clone();
             let position_ids = position_ids.clone();
-            let flash_meta = flash_meta.to_device(model.device())?;
+            let flash_meta = flash_meta.clone();
             handles.push(std::thread::spawn(move || {
                 #[cfg(feature = "cuda")]
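
For orientation, the shape of the change is: FlashParams now stores one cumulative-seqlen tensor per DeviceLocation, the inputs processor fills that map once for every unique device reported by the mapper, and flash_attn picks the entry matching the query tensor's device, so no per-call to_device copy (and no FlashParams::to_device helper) is needed. What follows is a minimal, self-contained sketch of the same pattern, not code from the patch; build_seqlens_map and seqlens_for_query are illustrative names, and only candle_core's public Device, DeviceLocation, and Tensor APIs are assumed.

use std::collections::HashMap;

use candle_core::{DType, Device, DeviceLocation, Result, Tensor};

// Build one copy of the cumulative sequence-length tensor per unique device,
// keyed by DeviceLocation -- the same layout FlashParams uses after this patch.
fn build_seqlens_map(
    seqlens: &Tensor,
    devices: &[Device],
) -> Result<HashMap<DeviceLocation, Tensor>> {
    let mut map = HashMap::new();
    for device in devices {
        // to_device yields a tensor resident on `device` (no transfer if it
        // already lives there).
        map.insert(device.location(), seqlens.to_device(device)?);
    }
    Ok(map)
}

// At attention time, look up the copy that lives on the same device as the
// query tensor, mirroring the `&cumulative_seqlens_q[&q.device().location()]`
// indexing added in attention.rs.
fn seqlens_for_query<'a>(
    map: &'a HashMap<DeviceLocation, Tensor>,
    q: &Tensor,
) -> Option<&'a Tensor> {
    map.get(&q.device().location())
}

fn main() -> Result<()> {
    // Cumulative lengths for two packed sequences of 5 and 4 tokens: [0, 5, 9].
    let seqlens_q = Tensor::new(&[0u32, 5, 9], &Device::Cpu)?;
    let map = build_seqlens_map(&seqlens_q, &[Device::Cpu])?;

    // A dummy "query" tensor; its device location selects the right entry.
    let q = Tensor::zeros((9, 8), DType::F32, &Device::Cpu)?;
    assert!(seqlens_for_query(&map, &q).is_some());
    Ok(())
}

Because candle tensors clone cheaply (the storage is shared), cloning a FlashParams built this way once per shard is inexpensive, which is what lets normal.rs replace the per-call to_device conversions with plain clone() calls.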