Multi-node support for tensor parallelism (#1125)
* Multi-node support, server?

* Fix check

* Remove check

* Remove check

* Debug

* Debug

* counter+=1

* counter+=1

* Replicate kv heads?

* Replicate kv heads?

* Replicate kv heads?

* Shard with specific offset

* Shard with specific offset

* Replicate kv heads

* Fix num kv groups

* Fix num kv groups

* Debug

* Refactor to client and server

* Use multi-node synchronization

* Debugging

* Hierarchical synchronization design

* It works!

* Add some docs

* Add logging

* Add logging

* Update docs

* Try to reuse socket

* Try to reuse socket

* Maybe its faster

* Minimize tcp traffic

* Clippy

* Set some timeouts
EricLBuehler authored Feb 11, 2025
1 parent 3fb29cc commit 844cdc0
Showing 9 changed files with 622 additions and 201 deletions.
160 changes: 80 additions & 80 deletions Cargo.lock

Large diffs are not rendered by default.

39 changes: 38 additions & 1 deletion docs/DISTRIBUTED.md
@@ -1,4 +1,4 @@
# Distributed inference in mistral.rs
# Distributed inference in mistral.rs: Tensor parallelism and Multi-node support

Mistral.rs supports distributed inference on CUDA with Tensor Parallelism via NCCL.

@@ -16,3 +16,40 @@ See the following environment variables:
|--|--|--|
|`MISTRALRS_NO_NCCL=1`|Disable TP and NCCL|If the model does not fit on the available CUDA devices, disabling NCCL will re-enable automatic device mapping|
|`MISTRALRS_PIPELINE_PARALLEL=<number> (default: 1 = disabled)`|Parallelize the model along the layers in addition to the GPUs|Increasing this value is useful for tuning performance on a model-specific basis. It does not change the number of GPUs required, but can help when the single-node interconnects are a bottleneck.|

## Multi-node support

```
# Head node:
MISTRALRS_MN_GLOBAL_WORLD_SIZE=32 MISTRALRS_MN_HEAD_NUM_WORKERS=3 MISTRALRS_MN_HEAD_PORT=<PORT> cargo run --release --features cuda -- -i plain -m ...
# For the worker nodes:
MISTRALRS_MN_GLOBAL_WORLD_SIZE=32 MISTRALRS_MN_WORKER_ID=0 MISTRALRS_MN_WORKER_SERVER_ADDR=<HEAD ADDR>:<PORT> cargo run --release --features cuda -- -i plain -m ...
MISTRALRS_MN_GLOBAL_WORLD_SIZE=32 MISTRALRS_MN_WORKER_ID=1 MISTRALRS_MN_WORKER_SERVER_ADDR=<HEAD ADDR>:<PORT> cargo run --release --features cuda -- -i plain -m ...
MISTRALRS_MN_GLOBAL_WORLD_SIZE=32 MISTRALRS_MN_WORKER_ID=2 MISTRALRS_MN_WORKER_SERVER_ADDR=<HEAD ADDR>:<PORT> cargo run --release --features cuda -- -i plain -m ...
```

Multi-node support in mistral.rs divides the nodes into two groups: one "head" node and one or more "worker" nodes. The choice of head node is arbitrary.
For example, if a system has 8 nodes, there will be 1 "head" node, and 7 "worker" nodes.

To enable multi-node, set the `MISTRALRS_MN_GLOBAL_WORLD_SIZE=<number>` environment variable to the total number of GPUs across all nodes, both "head" and "worker".

> Note: `MISTRALRS_PIPELINE_PARALLEL` is incompatible with multi-node (i.e., with setting `MISTRALRS_MN_GLOBAL_WORLD_SIZE`).

It is recommended to run mistral.rs in server mode when using multi-node. **Currently, you must send requests to every node!**

The following environment variables must be set on each node; a worked two-node example follows the tables below.

**Head node:**

|Name|Function|Usage|
|--|--|--|
|`MISTRALRS_MN_HEAD_NUM_WORKERS=<number>`|The number of worker nodes which will be connected.|This should be the number of nodes in the system, minus 1 for the head node.|
|`MISTRALRS_MN_HEAD_PORT=<PORT>`|The port on which to communicate with the worker nodes.|Worker nodes will connect to this port via TCP sockets|

**Worker node:**

|Name|Function|Usage|
|--|--|--|
|`MISTRALRS_MN_WORKER_ID=<number>`|The 0-indexed worker ID for this worker node.|If there are 4 nodes (1 head, 3 workers), then the worker ids will be 0, 1, and 2|
|`MISTRALRS_MN_WORKER_SERVER_ADDR=<ADDR>:<PORT>`|The IP address and port to connect to the server.|This is used to establish communication with the head node.|
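
As a concrete illustration of how these variables fit together, here is a minimal sketch for a hypothetical cluster of 2 nodes with 8 GPUs each (the GPU count, addresses, and model flags are placeholders):

```
# 2 nodes x 8 GPUs = 16 GPUs total  -> MISTRALRS_MN_GLOBAL_WORLD_SIZE=16
# 2 nodes - 1 head node             -> MISTRALRS_MN_HEAD_NUM_WORKERS=1, worker IDs start at 0

# Head node:
MISTRALRS_MN_GLOBAL_WORLD_SIZE=16 MISTRALRS_MN_HEAD_NUM_WORKERS=1 MISTRALRS_MN_HEAD_PORT=<PORT> cargo run --release --features cuda -- -i plain -m ...

# Worker node (ID 0):
MISTRALRS_MN_GLOBAL_WORLD_SIZE=16 MISTRALRS_MN_WORKER_ID=0 MISTRALRS_MN_WORKER_SERVER_ADDR=<HEAD ADDR>:<PORT> cargo run --release --features cuda -- -i plain -m ...
```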
31 changes: 27 additions & 4 deletions mistralrs-core/src/models/llama.rs
@@ -3,7 +3,7 @@
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer, Shard,
ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
@@ -193,20 +193,39 @@ impl CausalSelfAttention {
comm,
vb.pp("q_proj"),
)?;
let k_proj = ColumnParallelLayer::new(

// We may need to replicate the kv heads
let kv_replicate = if comm.world_size() > cfg.num_key_value_heads {
comm.world_size() / cfg.num_key_value_heads
} else {
1
};

let kv_shard_id = comm.rank() / kv_replicate;
// let kv_block_size = size_kv / comm.world_size();
let kv_block_size = cfg.hidden_size / cfg.num_attention_heads;
let shard = Shard::Offset {
dim: 0,
offset: kv_shard_id * kv_block_size,
len: kv_block_size,
};

let k_proj = ColumnParallelLayer::new_with_shard(
size_in,
size_kv,
&cfg.quantization_config,
false,
comm,
shard,
vb.pp("k_proj"),
)?;
let v_proj = ColumnParallelLayer::new(
let v_proj = ColumnParallelLayer::new_with_shard(
size_in,
size_kv,
&cfg.quantization_config,
false,
comm,
shard,
vb.pp("v_proj"),
)?;
let o_proj = RowParallelLayer::new(
@@ -229,7 +248,11 @@
max_seq_len: cfg.max_position_embeddings,
paged_attn,
sdpa_params: SdpaParams {
n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
n_kv_groups: if kv_replicate != 1 {
(cfg.num_attention_heads / cfg.num_key_value_heads) / kv_replicate
} else {
cfg.num_attention_heads / cfg.num_key_value_heads
},
use_flash_attn: cfg.use_flash_attn,
softcap: None,
softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
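The hunk above replicates KV heads when the tensor-parallel world size exceeds `num_key_value_heads`, shards `k_proj`/`v_proj` at an explicit offset, and shrinks the per-rank KV-group count accordingly. A minimal standalone sketch of that arithmetic (illustrative values only; the names are local to this sketch, not the crate's API):

```rust
// Standalone sketch of the KV-head replication arithmetic; all values are
// illustrative (a Llama-style config on 16 tensor-parallel ranks).
fn main() {
    let num_attention_heads = 32usize;
    let num_key_value_heads = 8usize;
    let hidden_size = 4096usize;
    let world_size = 16usize; // tensor-parallel ranks across all nodes

    // If there are more ranks than KV heads, each KV head is replicated
    // across `kv_replicate` consecutive ranks.
    let kv_replicate = if world_size > num_key_value_heads {
        world_size / num_key_value_heads
    } else {
        1
    };

    // Each rank loads one attention head's worth of K/V rows.
    let kv_block_size = hidden_size / num_attention_heads;

    // With replication, each rank attends with fewer query heads per KV head.
    let n_kv_groups = if kv_replicate != 1 {
        (num_attention_heads / num_key_value_heads) / kv_replicate
    } else {
        num_attention_heads / num_key_value_heads
    };

    for rank in 0..world_size {
        // Ranks that share a `kv_shard_id` load the same K/V slice.
        let kv_shard_id = rank / kv_replicate;
        let offset = kv_shard_id * kv_block_size;
        println!(
            "rank {rank}: kv_shard_id={kv_shard_id}, shard offset={offset}, len={kv_block_size}, n_kv_groups={n_kv_groups}"
        );
    }
}
```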
103 changes: 93 additions & 10 deletions mistralrs-core/src/pipeline/normal.rs
@@ -38,7 +38,7 @@ use crate::{
normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
PagedAttentionConfig, Pipeline, Topology, TryIntoDType,
};
use anyhow::Result;
use anyhow::{Context, Result};
use candle_core::{Device, Tensor, Var};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
@@ -51,12 +51,12 @@ use rayon::iter::{
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::fs;
use std::num::{NonZero, NonZeroUsize};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, Barrier, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};
@@ -300,7 +300,8 @@ impl Loader for NormalLoader {

let use_nccl = available_devices.iter().all(|dev| dev.is_cuda())
&& available_devices.len() > 1
&& std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1");
&& (std::env::var("MISTRALRS_NO_NCCL").is_err()
|| std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"));

// If auto, convert to Map if not using nccl
if use_nccl {
@@ -476,15 +477,63 @@ impl Loader for NormalLoader {
anyhow::bail!("MISTRALRS_PIPELINE_PARALLEL must be nonzero")
}

let world_size = available_devices.len() / pipeline_parallel_size;
let local_world_size = available_devices.len() / pipeline_parallel_size;
let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
} else {
local_world_size
};

let use_multi_node = global_world_size != local_world_size;
if use_multi_node {
info!("Global world size != local world size, entering multi-node.");
}

info!("Tensor parallel world size is {world_size}");
if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
anyhow::bail!("Global world size {global_world_size} must both be at least and divide the local world size {local_world_size}");
}

info!("Local tensor parallel world size is {local_world_size}");
info!("Global tensor parallel world size is {global_world_size}");
info!("Pipeline parallelism size is {pipeline_parallel_size}");

let ids = (0..pipeline_parallel_size)
let mut ids = (0..pipeline_parallel_size)
.map(|_| mistralrs_quant::Id::new())
.collect::<Vec<_>>();

if ids.len() != 1 {
anyhow::bail!(
"MISTRALRS_PIPELINE_PARALLEL cannot be set at the same time as MISTRALRS_MN_GLOBAL_WORLD_SIZE; multi-node is incompatible with pipeline parallel."
);
}

if use_multi_node {
let id = &mut ids[0];
if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
let n_nodes =
usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
info!("Head node managing {n_nodes} workers.");
let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
anyhow::bail!(
"Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT"
);
};
info!("Head node initializing connection on {port}.");
let server = mistralrs_quant::Server::new(
&format!("0.0.0.0:{port}"),
n_nodes,
local_world_size,
)?;

server.broadcast_id(id)?;
} else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
info!("Worker node connecting to {addr}.");
let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;

*id = client.receive_id()?;
}
}

if available_devices.len() % ids.len() != 0 {
anyhow::bail!(
"Pipeline parallel size {} must divide the number of available devices {}",
@@ -497,13 +546,47 @@
.chunks(available_devices.len() / pipeline_parallel_size)
.collect::<Vec<_>>();

let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
anyhow::bail!(
"Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID"
);
};
let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
info!("Worker ID is {node_id}.");
(node_id + 1) * local_world_size
} else {
0
};

// Transpose
let mut comms_all = Vec::new();
for (pipeline_parallel_i, devices_per_pipeline_parallel) in
split_available_devices.iter().enumerate()
{
// Each pipeline parallel gets its own barrier
let barrier = Arc::new(Barrier::new(world_size));
let barrier = if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
let n_nodes =
usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
anyhow::bail!(
"Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT"
);
};
let server = mistralrs_quant::Server::new(
&format!("0.0.0.0:{port}"),
n_nodes,
local_world_size,
)?;

Arc::new(server) as Arc<dyn mistralrs_quant::BarrierLike>
} else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;
Arc::new(client) as Arc<dyn mistralrs_quant::BarrierLike>
} else {
Arc::new(Barrier::new(local_world_size))
as Arc<dyn mistralrs_quant::BarrierLike>
};

// They each block on each other
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html?ncclcomminitrank#ncclcomminitrank
@@ -522,8 +605,8 @@
mistralrs_quant::Comm::from_device(
ids[pipeline_parallel_i],
device,
rank,
world_size,
rank + rank_offset,
global_world_size,
barrier.clone(),
)
})
@@ -539,7 +622,7 @@

// row major: number of ranks x pipeline parallel
// Also corresponds to the device for that comm for the
let comms = (0..world_size)
let comms = (0..local_world_size)
.map(|pipeline_parallel_i| {
comms_all
.iter()
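The changes above give every GPU a global NCCL rank by adding a per-node offset, and swap the in-process `Barrier` for a TCP-backed server/client barrier on the head and worker nodes. A minimal standalone sketch of the resulting rank layout (hypothetical topology; the helper below is not part of the crate):

```rust
// Standalone sketch of the multi-node rank layout; names and values are
// local to this example. The head node has no MISTRALRS_MN_WORKER_ID and
// takes global ranks 0..local_world_size; worker `i` takes the next block,
// offset by (i + 1) * local_world_size.
fn global_rank(local_rank: usize, local_world_size: usize, worker_id: Option<usize>) -> usize {
    let rank_offset = match worker_id {
        Some(id) => (id + 1) * local_world_size,
        None => 0, // head node
    };
    local_rank + rank_offset
}

fn main() {
    // Assumed topology: 1 head + 2 workers, 4 GPUs per node (global world size 12).
    let local_world_size = 4usize;
    for (node, worker_id) in [("head", None), ("worker 0", Some(0)), ("worker 1", Some(1))] {
        let ranks: Vec<usize> = (0..local_world_size)
            .map(|r| global_rank(r, local_world_size, worker_id))
            .collect();
        println!("{node}: global ranks {ranks:?}");
    }
}
```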
25 changes: 19 additions & 6 deletions mistralrs-quant/src/distributed/layers.rs
@@ -10,7 +10,7 @@ use crate::{
};

fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
Shard {
Shard::Simple {
dim,
rank,
world_size,
@@ -195,18 +195,15 @@ pub struct ColumnParallelLayer {

impl ColumnParallelLayer {
#[allow(clippy::new_ret_no_self)]
pub fn new(
pub fn new_with_shard(
in_dim: usize,
out_dim: usize,
config: &Option<QuantizedConfig>,
bias: bool,
comm: &Arc<crate::Comm>,
shard: Shard,
vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
let rank = comm.rank();
let world_size = comm.world_size();
let shard = shard(0, rank, world_size);

let weight = if let Some(quant_conf) = &config {
// GPTQ and BNB do not support tensor parallelism
if matches!(
@@ -259,6 +256,22 @@

Ok(Arc::new(Self { weight, bias }))
}

#[allow(clippy::new_ret_no_self)]
pub fn new(
in_dim: usize,
out_dim: usize,
config: &Option<QuantizedConfig>,
bias: bool,
comm: &Arc<crate::Comm>,
vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
let rank = comm.rank();
let world_size = comm.world_size();
let shard = shard(0, rank, world_size);

Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
}
}

impl QuantMethod for ColumnParallelLayer {