Skip to content

Commit

Permalink
Fix compute_kv_shard for world size which is not a multiple of kv heads
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Feb 16, 2025
1 parent 098776f commit d01ae68
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion mistralrs-quant/src/distributed/layers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,11 @@ pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm)
let kv_replicate = if comm.world_size() > total_num_kv_heads {
comm.world_size() / total_num_kv_heads
} else {
1
return Shard::Simple {
dim: 0,
rank: comm.rank(),
world_size: comm.world_size(),
};
};

let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
Expand Down

0 comments on commit d01ae68

Please sign in to comment.