Skip to content

Commit

Permalink
FlashAttention V2/V3 metadata with support for device location (#1148)
Browse files: browse the repository at this point in the history
* Flash attn metadata with support for device location

* Actually cast
  • Loading branch information
EricLBuehler authored Feb 16, 2025
1 parent d01ae68 commit 92cfd77
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
6 changes: 6 additions & 0 deletions mistralrs-core/src/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ fn flash_attn(
let window_size_left = sdpa_params.sliding_window;
let window_size_right = if causal { Some(0) } else { None };

let cumulative_seqlens_q = &cumulative_seqlens_q[&q.device().location()];
let cumulative_seqlens_k = &cumulative_seqlens_k[&q.device().location()];

candle_flash_attn::flash_attn_varlen_windowed_softcap(
&q,
&k,
Expand Down Expand Up @@ -94,6 +97,9 @@ fn flash_attn(
let window_size_left = sdpa_params.sliding_window;
let window_size_right = if causal { Some(0) } else { None };

let cumulative_seqlens_q = &cumulative_seqlens_q[&q.device().location()];
let cumulative_seqlens_k = &cumulative_seqlens_k[&q.device().location()];

candle_flash_attn_v3::flash_attn_varlen_windowed(
&q,
&k,
Expand Down
41 changes: 24 additions & 17 deletions mistralrs-core/src/pipeline/inputs_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,8 @@ pub mod text_models_inputs_processor {
pub struct FlashParams {
pub max_q: u32,
pub max_k: u32,
pub cumulative_seqlens_q: Tensor,
pub cumulative_seqlens_k: Tensor,
}

impl FlashParams {
pub fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
Ok(Self {
max_k: self.max_k,
max_q: self.max_q,
cumulative_seqlens_k: self.cumulative_seqlens_k.to_device(device)?,
cumulative_seqlens_q: self.cumulative_seqlens_q.to_device(device)?,
})
}
pub cumulative_seqlens_q: HashMap<DeviceLocation, Tensor>,
pub cumulative_seqlens_k: HashMap<DeviceLocation, Tensor>,
}

pub struct InputMetadata {
Expand Down Expand Up @@ -268,6 +257,15 @@ pub mod text_models_inputs_processor {
.cumsum(0)?
.to_dtype(DType::U32)?;

let mut seqlens_q_map = HashMap::new();
let mut seqlens_k_map = HashMap::new();

let devices = mapper.unwrap().get_unique_devices();
for device in devices {
seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
}

let input = Tensor::cat(&seqs_tensors, 0).unwrap();
// Only use matmul via f16 if prompt and seqlen > 512
if input.dim(1)? > VIA_F16_TOK_THRESHOLD {
Expand Down Expand Up @@ -345,8 +343,8 @@ pub mod text_models_inputs_processor {
flash_meta: FlashParams {
max_k,
max_q,
cumulative_seqlens_k: seqlens_k,
cumulative_seqlens_q: seqlens_q,
cumulative_seqlens_k: seqlens_k_map,
cumulative_seqlens_q: seqlens_q_map,
},
})
}
Expand Down Expand Up @@ -438,6 +436,15 @@ pub mod text_models_inputs_processor {
.cumsum(0)?
.to_dtype(DType::U32)?;

let mut seqlens_q_map = HashMap::new();
let mut seqlens_k_map = HashMap::new();

let devices = mapper.unwrap().get_unique_devices();
for device in devices {
seqlens_q_map.insert(device.location(), seqlens_q.to_device(&device)?);
seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
}

set_use_matmul_via_f16(false);

let paged_attn_meta = if paged_attn_metadata.is_some() {
Expand Down Expand Up @@ -502,8 +509,8 @@ pub mod text_models_inputs_processor {
flash_meta: FlashParams {
max_k,
max_q,
cumulative_seqlens_k: seqlens_k,
cumulative_seqlens_q: seqlens_q,
cumulative_seqlens_k: seqlens_k_map,
cumulative_seqlens_q: seqlens_q_map,
},
})
}
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ impl Loader for NormalLoader {
inputs.context_lens.clone(),
inputs.position_ids.clone(),
None,
&inputs.flash_meta.to_device(model.device())?,
&inputs.flash_meta.clone(),
)
})
.collect::<candle_core::Result<Vec<_>>>()?;
Expand Down Expand Up @@ -1314,7 +1314,7 @@ impl Pipeline for NormalPipeline {
let seqlen_offsets = seqlen_offsets.clone();
let context_lens = context_lens.clone();
let position_ids = position_ids.clone();
let flash_meta = flash_meta.to_device(model.device())?;
let flash_meta = flash_meta.clone();

handles.push(std::thread::spawn(move || {
#[cfg(feature = "cuda")]
Expand Down

0 comments on commit 92cfd77

Please sign in to comment.