Skip to content

Commit

Permalink
Fix device map check for paged attn (#656)
Browse files: browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Jul 31, 2024
1 parent d970bb5 commit 29f58bc
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
7 changes: 6 additions & 1 deletion mistralrs-core/src/pipeline/ggml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ impl Loader for GGMLLoader {
silent: bool,
mapper: DeviceMapMetadata,
in_situ_quant: Option<GgmlDType>,
paged_attn_config: Option<PagedAttentionConfig>,
mut paged_attn_config: Option<PagedAttentionConfig>,
) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
if in_situ_quant.is_some() {
anyhow::bail!(
Expand All @@ -243,6 +243,11 @@ impl Loader for GGMLLoader {
if !mapper.is_dummy() {
warn!("GGML models do not support device mapping. Device mapping will not work. Please consider using a GGUF model.");
}
if !mapper.is_dummy() && paged_attn_config.is_some() {
warn!("Device mapping and PagedAttention are incompatible, disabling PagedAttention.");
paged_attn_config = None;
}

info!(
"Loading model `{}` on {}.",
self.get_id(),
Expand Down
5 changes: 4 additions & 1 deletion mistralrs-core/src/pipeline/gguf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ impl Loader for GGUFLoader {
silent: bool,
mapper: DeviceMapMetadata,
in_situ_quant: Option<GgmlDType>,
paged_attn_config: Option<PagedAttentionConfig>,
mut paged_attn_config: Option<PagedAttentionConfig>,
) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
if in_situ_quant.is_some() {
anyhow::bail!(
Expand All @@ -377,6 +377,9 @@ impl Loader for GGUFLoader {
self.get_id(),
device.device_pretty_repr()
);
} else if paged_attn_config.is_some() {
warn!("Device mapping and PagedAttention are incompatible, disabling PagedAttention.");
paged_attn_config = None;
}

let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
Expand Down
5 changes: 4 additions & 1 deletion mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ impl Loader for NormalLoader {
silent: bool,
mapper: DeviceMapMetadata,
in_situ_quant: Option<GgmlDType>,
paged_attn_config: Option<PagedAttentionConfig>,
mut paged_attn_config: Option<PagedAttentionConfig>,
) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
let config = std::fs::read_to_string(paths.get_config_filename())?;
let dtype = dtype.try_into_dtype(device)?;
Expand All @@ -236,6 +236,9 @@ impl Loader for NormalLoader {
self.get_id(),
device.device_pretty_repr()
);
} else if paged_attn_config.is_some() {
warn!("Device mapping and PagedAttention are incompatible, disabling PagedAttention.");
paged_attn_config = None;
}

info!(
Expand Down
7 changes: 5 additions & 2 deletions mistralrs-core/src/pipeline/vision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use std::str::FromStr;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::info;
use tracing::{info, warn};

pub struct VisionPipeline {
model: Box<dyn VisionModel + Send + Sync>,
Expand Down Expand Up @@ -157,7 +157,7 @@ impl Loader for VisionLoader {
silent: bool,
mapper: DeviceMapMetadata,
in_situ_quant: Option<GgmlDType>,
paged_attn_config: Option<PagedAttentionConfig>,
mut paged_attn_config: Option<PagedAttentionConfig>,
) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
let config = std::fs::read_to_string(paths.get_config_filename())?;
let dtype = dtype.try_into_dtype(device)?;
Expand All @@ -169,6 +169,9 @@ impl Loader for VisionLoader {
self.get_id(),
device.device_pretty_repr()
);
} else if paged_attn_config.is_some() {
warn!("Device mapping and PagedAttention are incompatible, disabling PagedAttention.");
paged_attn_config = None;
}

info!(
Expand Down

0 comments on commit 29f58bc

Please sign in to comment.