Skip to content

Commit

Permalink
Add an NCCL feature flag (#1129)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Feb 11, 2025
1 parent 844cdc0 commit bd5532c
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/DISTRIBUTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ TP splits the model into shards and benefits from fast single-node interconnects

> Note: In mistral.rs, if NCCL is enabled, then automatic device mapping *will not* be used.
**Important**: To build for NCCL, be sure to add the `nccl` feature flag (for example: `--features nccl,cuda`).

See the following environment variables:

|Name|Function|Usage|
Expand Down
1 change: 1 addition & 0 deletions mistralrs-bench/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ metal = ["mistralrs-core/metal"]
flash-attn = ["cuda", "mistralrs-core/flash-attn"]
accelerate = ["mistralrs-core/accelerate"]
mkl = ["mistralrs-core/mkl"]
nccl = ["mistralrs-core/nccl"]
2 changes: 1 addition & 1 deletion mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ safetensors.workspace = true
pyo3_macros = ["pyo3"]
cuda = [
"candle-core/cuda",
"candle-core/nccl",
"candle-nn/cuda",
"dep:bindgen_cuda",
"mistralrs-quant/cuda",
Expand All @@ -110,6 +109,7 @@ flash-attn = ["cuda", "dep:candle-flash-attn"]
flash-attn-v3 = ["cuda", "dep:candle-flash-attn-v3"]
accelerate = ["candle-core/accelerate", "candle-nn/accelerate", "mistralrs-quant/accelerate"]
mkl = ["candle-core/mkl", "candle-nn/mkl"]
nccl = ["cuda", "mistralrs-quant/nccl"]

[build-dependencies]
bindgen_cuda = { version = "0.1.5", optional = true }
8 changes: 7 additions & 1 deletion mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ impl Loader for NormalLoader {
let use_nccl = available_devices.iter().all(|dev| dev.is_cuda())
&& available_devices.len() > 1
&& (std::env::var("MISTRALRS_NO_NCCL").is_err()
|| std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"));
|| std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"))
&& cfg!(feature = "nccl");

// If auto, convert to Map if not using nccl
if use_nccl {
Expand Down Expand Up @@ -463,6 +464,11 @@ impl Loader for NormalLoader {
let multi_progress = Arc::new(MultiProgress::new());

let mut parallel_models = if use_nccl {
#[cfg(not(feature = "nccl"))]
warn!(
"NCCL support was included in the build, be sure to build with `--features nccl`."
);

// NCCL case!

let pipeline_parallel_size = std::env::var("MISTRALRS_PIPELINE_PARALLEL")
Expand Down
7 changes: 6 additions & 1 deletion mistralrs-quant/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ safetensors.workspace = true
regex.workspace = true

[features]
cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda"]
cuda = [
"candle-core/cuda",
"candle-nn/cuda",
"dep:bindgen_cuda"
]
nccl = ["cuda", "candle-core/nccl"]
metal = ["candle-core/metal", "candle-nn/metal", "dep:metal"]
accelerate = ["candle-core/accelerate", "candle-nn/accelerate"]

Expand Down
4 changes: 2 additions & 2 deletions mistralrs-quant/src/distributed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl BarrierLike for Barrier {
}
}

#[cfg(feature = "cuda")]
#[cfg(all(feature = "cuda", feature = "nccl"))]
mod ops {
use std::{fmt::Debug, ops::Deref, sync::Arc};

Expand Down Expand Up @@ -165,7 +165,7 @@ mod ops {
}
}

#[cfg(not(feature = "cuda"))]
#[cfg(not(all(feature = "cuda", feature = "nccl")))]
mod ops {
use std::sync::Arc;

Expand Down
1 change: 1 addition & 0 deletions mistralrs-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ metal = ["mistralrs-core/metal"]
flash-attn = ["cuda", "mistralrs-core/flash-attn"]
accelerate = ["mistralrs-core/accelerate"]
mkl = ["mistralrs-core/mkl"]
nccl = ["mistralrs-core/nccl"]
1 change: 1 addition & 0 deletions mistralrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ metal = ["mistralrs-core/metal"]
flash-attn = ["cuda", "mistralrs-core/flash-attn"]
accelerate = ["mistralrs-core/accelerate"]
mkl = ["mistralrs-core/mkl"]
nccl = ["mistralrs-core/nccl"]

[[example]]
name = "simple"
Expand Down

0 comments on commit bd5532c

Please sign in to comment.