Lower memory spike when loading with ISQ on CUDA (#433)
* Synchronize device to lower memory usage

* Limit thread count if on cuda

* Add timing, tune thread count, and display info

* Add isq low memory env var
EricLBuehler authored Jun 14, 2024
1 parent 1ae5bfe commit 6648673
Showing 3 changed files with 37 additions and 5 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -72,8 +72,9 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis
- Lightweight OpenAI API compatible HTTP server.
- Python API.
- Grammar support with Regex and Yacc.
- [ISQ](docs/ISQ.md) (In situ quantization): run `.safetensors` models directly from Hugging Face Hub by quantizing them after loading instead of creating a GGUF file. This loads the ISQ-able weights on CPU before quantizing with ISQ and then moving back to the device to avoid memory spikes.
- [ISQ](docs/ISQ.md) (In situ quantization): run `.safetensors` models directly from Hugging Face Hub by quantizing them after loading instead of creating a GGUF file.
- This loads the ISQ-able weights on CPU before quantizing with ISQ and then moving to the device to avoid memory spikes.
- Provides methods to further reduce memory spikes.
**Powerful**:
- Fast LoRA support with weight merging.
- First X-LoRA inference platform with first class support.
10 changes: 10 additions & 0 deletions docs/ISQ.md
@@ -22,6 +22,16 @@ If a tensor cannot be quantized, the fallback process is as follows:
1) If using a `K` quant, fallback to a similar `Q` quant.
2) If that is not possible, use `F32` as the data type.
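A hedged sketch of what such a fallback can look like; the exact `K`-to-`Q` mapping below is illustrative, not copied from the source:

```rust
use candle_core::quantized::GgmlDType;

/// Illustrative mapping from a `K` quant to a similar `Q` quant; `None` means
/// step 2 applies and the tensor is kept as `F32` instead.
fn k_to_q_fallback(dtype: GgmlDType) -> Option<GgmlDType> {
    match dtype {
        GgmlDType::Q2K | GgmlDType::Q3K | GgmlDType::Q4K => Some(GgmlDType::Q4_0),
        GgmlDType::Q5K => Some(GgmlDType::Q5_0),
        GgmlDType::Q6K | GgmlDType::Q8K => Some(GgmlDType::Q8_0),
        _ => None,
    }
}
```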

## Avoiding memory spikes

On non-Metal systems, the tensors are copied to the device and quantized in parallel. On CUDA devices this can pose a problem: the copies of the full-precision tensors are asynchronous, so the source buffers are deallocated later than necessary, causing a memory spike (smaller than loading the entire model on the GPU, but still significant).

To mitigate this on CUDA systems, you can set the `ISQ_LOW_MEMORY` environment variable to significantly reduce the remaining memory spike:

```bash
ISQ_LOW_MEMORY=1 cargo run --release --features cuda -- --isq Q4K -i plain -m microsoft/Phi-3-mini-128k-instruct -a phi3
```
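Under the hood, each tensor goes through a copy-quantize-synchronize step. A hedged sketch of that step, adapted from the `generate_isq!` macro in this commit (the function name and signature are illustrative; error handling simplified):

```rust
use std::sync::Arc;

use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    Device, Result, Tensor,
};

/// One ISQ step: copy the weight to the device, quantize it there, then
/// synchronize so the asynchronous copy completes and the full-precision
/// source buffer can be freed before the next tensor is handled.
fn quantize_one(weight: &Tensor, device: &Device, dtype: GgmlDType) -> Result<QMatMul> {
    let t = weight.to_device(device)?;
    let q = QTensor::quantize(&t, dtype)?;
    device.synchronize()?;
    Ok(QMatMul::QTensor(Arc::new(q)))
}
```

Synchronizing after each quantization trades a little throughput for promptly freeing the full-precision source buffers.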

## Python Example
```python
runner = Runner(
27 changes: 24 additions & 3 deletions mistralrs-core/src/pipeline/isq.rs
@@ -1,4 +1,7 @@
use std::sync::{atomic::AtomicUsize, Arc};
use std::{
sync::{atomic::AtomicUsize, Arc},
time::Instant,
};

use candle_core::{
quantized::{GgmlDType, QMatMul, QTensor},
@@ -9,6 +12,9 @@ use tracing::{info, warn};

use crate::device_map::DeviceMapper;

#[cfg(feature = "cuda")]
const ISQ_THREAD_COUNT: usize = 4;

pub enum QuantizationBehaviour {
Quantize(GgmlDType),
Skip,
@@ -68,7 +74,8 @@ macro_rules! generate_isq {
$n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
QMatMul::QTensor(Arc::new(QTensor::quantize(&t, dtype).unwrap()))
}
}
};
$device.synchronize().unwrap();
}
};
}
@@ -101,8 +108,21 @@ pub trait IsqModel {
devices.push(device.clone());
}

let t_start = Instant::now();
#[cfg(not(feature = "metal"))]
{
#[cfg(feature = "cuda")]
{
let isq_low_mem = std::env::var("ISQ_LOW_MEMORY").is_ok();
if isq_low_mem {
rayon::ThreadPoolBuilder::new()
.num_threads(ISQ_THREAD_COUNT)
.build_global()
.expect("Failed to build global thread pool");
}
}
info!("Applying ISQ on {} threads.", rayon::current_num_threads());

use indicatif::ParallelProgressIterator;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
tensors
@@ -125,7 +145,8 @@
generate_isq!(tensor, device, dtype, n_quantized)
});
}
info!("Applied in-situ quantization into {dtype:?} to {n_quantized:?} tensors out of {total_tensors} total tensors.");
let delta = Instant::now().duration_since(t_start).as_secs_f32();
info!("Applied in-situ quantization into {dtype:?} to {n_quantized:?} tensors out of {total_tensors} total tensors. Took {delta:.2}s", );

Ok(())
}
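For reference, a standalone sketch of how the low-memory gating above fits together, assuming only the `rayon` crate; the constant and log wording mirror the diff, the rest is illustrative:

```rust
use std::time::Instant;

const ISQ_THREAD_COUNT: usize = 4;

fn main() {
    // Opt-in: a smaller global rayon pool means fewer tensors are in flight
    // at once, so less full-precision data is resident at any moment.
    // `build_global` must run before the first parallel iterator is used,
    // and only once per process.
    if std::env::var("ISQ_LOW_MEMORY").is_ok() {
        rayon::ThreadPoolBuilder::new()
            .num_threads(ISQ_THREAD_COUNT)
            .build_global()
            .expect("Failed to build global thread pool");
    }
    println!("Applying ISQ on {} threads.", rayon::current_num_threads());

    let t_start = Instant::now();
    // ... the parallel quantization over all tensors would run here ...
    println!("Took {:.2}s", t_start.elapsed().as_secs_f32());
}
```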
