Skip to content

Commit

Permalink
adding reboot functionality (#378)
Browse files Browse the repository at this point in the history
* adding reboot functionality

* clippy

* addressing comments

* thread handler checking

* adding errors, removing extra logs

* simplifying error display impl

* fixing phi3v impl and the docs

* nit: capitalizing error messages
  • Loading branch information
gregszumel authored Jun 13, 2024
1 parent 07bf840 commit 1ae5bfe
Show file tree
Hide file tree
Showing 19 changed files with 163 additions and 41 deletions.
4 changes: 2 additions & 2 deletions docs/PHI3V.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ fn main() -> anyhow::Result<()> {
suffix: None,
adapters: None,
});
mistralrs.get_sender().blocking_send(request)?;
mistralrs.get_sender()?.blocking_send(request)?;

let response = rx.blocking_recv().unwrap();
match response {
Expand Down Expand Up @@ -215,4 +215,4 @@ print(res.usage)
```

- You can find an example of encoding the [image via base64 here](../examples/python/phi3v_base64.py).
- You can find an example of loading an [image locally here](../examples/python/phi3v_local_img.py).
- You can find an example of loading an [image locally here](../examples/python/phi3v_local_img.py).
4 changes: 2 additions & 2 deletions mistralrs-bench/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fn run_bench(
logits_bias: None,
n_choices: 1,
};
let sender = mistralrs.get_sender();
let sender = mistralrs.get_sender().unwrap();
let (tx, mut rx) = channel(10_000);

let req = Request::Normal(NormalRequest {
Expand Down Expand Up @@ -221,7 +221,7 @@ fn warmup_run(mistralrs: Arc<MistralRs>) {
logits_bias: None,
n_choices: 1,
};
let sender = mistralrs.get_sender();
let sender = mistralrs.get_sender().unwrap();
let (tx, mut rx) = channel(10_000);

let req = Request::Normal(NormalRequest {
Expand Down
141 changes: 124 additions & 17 deletions mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ pub use engine::TERMINATE_ALL_NEXT_STEP;
pub use lora::Ordering;
use pipeline::ModelCategory;
pub use pipeline::Pipeline;
use pyo3::exceptions::PyValueError;
use std::{
cell::RefCell,
error::Error,
fs::OpenOptions,
io::Write,
sync::{atomic::AtomicBool, Arc, Mutex},
thread,
sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
thread::{self, JoinHandle},
time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};
Expand Down Expand Up @@ -76,11 +77,44 @@ pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// `Sender` and `Receiver` primitives to send and receive requests to the
/// engine.
pub struct MistralRs {
sender: Sender<Request>,
sender: RwLock<Sender<Request>>,
log: Option<String>,
id: String,
creation_time: u64,
next_request_id: Mutex<RefCell<usize>>,
reboot_state: RebootState,
engine_handler: RwLock<JoinHandle<()>>,
}

/// Pipeline handle and engine settings that `reboot_engine` passes to
/// `Engine::new` when it starts a replacement engine.
#[derive(Clone)]
struct RebootState {
// Shared handle to the pipeline; `reboot_engine` clones it into the new engine.
pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
// Scheduler method forwarded to `Engine::new`.
method: SchedulerMethod,
// The remaining fields are forwarded to `Engine::new` as-is, in this order.
truncate_sequence: bool,
no_kv_cache: bool,
no_prefix_cache: bool,
prefix_cache_n: usize,
disable_eos_stop: bool,
}

/// Errors reported when one of the locks guarding the engine state cannot be
/// acquired.
#[derive(Debug)]
pub enum MistralRsError {
    /// The lock around the engine thread handle could not be acquired.
    EnginePoisoned,
    /// The lock around the request sender could not be acquired.
    SenderPoisoned,
}

impl std::fmt::Display for MistralRsError {
    /// Renders the error exactly like its `Debug` form, i.e. the bare variant
    /// name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}

impl std::error::Error for MistralRsError {}

impl From<MistralRsError> for pyo3::PyErr {
    /// Wraps the error in a `PyValueError` whose message is the error's
    /// `Debug` rendering.
    fn from(err: MistralRsError) -> Self {
        let message = format!("{:?}", err);
        PyValueError::new_err(message)
    }
}

/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
Expand Down Expand Up @@ -216,19 +250,22 @@ impl MistralRs {
let prefix_cache_n = prefix_cache_n.unwrap_or(16);
let disable_eos_stop = disable_eos_stop.unwrap_or(false);

let reboot_state = RebootState {
pipeline: pipeline.clone(),
method: method.clone(),
truncate_sequence,
no_kv_cache,
no_prefix_cache,
prefix_cache_n,
disable_eos_stop,
};

let (tx, rx) = channel(10_000);

let this = Arc::new(Self {
sender: tx,
log,
id: pipeline.try_lock().unwrap().name(),
creation_time: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!")
.as_secs(),
next_request_id: Mutex::new(RefCell::new(0)),
});
thread::spawn(move || {
let sender = RwLock::new(tx);
let id = pipeline.try_lock().unwrap().name();

let engine_handler = thread::spawn(move || {
let rt = Runtime::new().unwrap();
rt.block_on(async move {
let mut engine = Engine::new(
Expand All @@ -245,11 +282,81 @@ impl MistralRs {
});
});

this
Arc::new(Self {
sender,
log,
id,
creation_time: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!")
.as_secs(),
next_request_id: Mutex::new(RefCell::new(0)),
reboot_state,
engine_handler: RwLock::new(engine_handler),
})
}

/// Starts a replacement engine thread if the current one has finished.
///
/// Write locks on both the sender and the engine thread handle are held for
/// the rest of the call. When the current engine thread is still running,
/// nothing is replaced and `Ok(())` is returned. Otherwise a new engine is
/// spawned from the saved `reboot_state`, and the stored sender and thread
/// handle are swapped for the new ones.
///
/// # Errors
///
/// Returns `MistralRsError::SenderPoisoned` or
/// `MistralRsError::EnginePoisoned` when the corresponding write lock cannot
/// be acquired.
fn reboot_engine(&self) -> Result<(), MistralRsError> {
    let (fresh_sender, receiver) = channel(10_000);
    let state = self.reboot_state.clone();

    let mut sender_guard = self.sender.write().map_err(|_| {
        tracing::warn!("Couldn't get write lock on the sender during reboot attempt");
        MistralRsError::SenderPoisoned
    })?;
    let mut handle_guard = self.engine_handler.write().map_err(|_| {
        tracing::warn!("Couldn't get write lock on the engine during reboot attempt");
        MistralRsError::EnginePoisoned
    })?;

    // Nothing to do while the current engine thread is still alive.
    if !handle_guard.is_finished() {
        tracing::info!("Engine already running, returning ok");
        return Ok(());
    }

    // Critical section: both write guards are held, so a panic from here on
    // could leave the locks poisoned.
    let fresh_handle = thread::spawn(move || {
        let runtime = Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut engine = Engine::new(
                receiver,
                state.pipeline.clone(),
                state.method,
                state.truncate_sequence,
                state.no_kv_cache,
                state.no_prefix_cache,
                state.prefix_cache_n,
                state.disable_eos_stop,
            );
            engine.run().await;
        });
    });

    *sender_guard = fresh_sender;
    *handle_guard = fresh_handle;
    tracing::info!("Successfully rebooted engine and updated sender + engine handler");
    Ok(())
}

pub fn get_sender(&self) -> Sender<Request> {
self.sender.clone()
/// Reports whether the engine thread has finished running.
///
/// # Errors
///
/// Logs a warning and returns `MistralRsError::EnginePoisoned` when the read
/// lock on the engine thread handle cannot be acquired.
fn engine_dead(&self) -> Result<bool, MistralRsError> {
    self.engine_handler
        .read()
        .map(|handle| handle.is_finished())
        .map_err(|_| {
            tracing::warn!("Couldn't get read lock on engine!");
            MistralRsError::EnginePoisoned
        })
}

pub fn get_sender(&self) -> Result<Sender<Request>, MistralRsError> {
if self.engine_dead()? {
tracing::warn!("Engine is dead, rebooting");
self.reboot_engine()?
}
match self.sender.read() {
Ok(sender) => Ok(sender.clone()),
Err(_) => Err(MistralRsError::SenderPoisoned),
}
}

pub fn get_id(&self) -> String {
Expand Down
11 changes: 11 additions & 0 deletions mistralrs-core/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ pub enum SchedulerMethod {
Fixed(UsizeBounded<1, { usize::MAX }, false>),
}

impl Clone for SchedulerMethod {
fn clone(&self) -> Self {
match self {
SchedulerMethod::Fixed(val) => {
let v = **val;
SchedulerMethod::Fixed(v.try_into().unwrap())
}
}
}
}

pub struct BucketedSeqs<Backer: FcfsBacker> {
running: Vec<Sequence>,
waiting: Backer,
Expand Down
12 changes: 8 additions & 4 deletions mistralrs-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ impl Runner {
});

MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}"));
let sender = self.runner.get_sender();
let sender = self.runner.get_sender()?;
sender.blocking_send(model_request).unwrap();

if request.stream {
Expand Down Expand Up @@ -764,7 +764,7 @@ impl Runner {
});

MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}"));
let sender = self.runner.get_sender();
let sender = self.runner.get_sender()?;
sender.blocking_send(model_request).unwrap();
let response = rx.blocking_recv().unwrap();

Expand All @@ -788,14 +788,18 @@ impl Runner {
fn send_re_isq(&self, dtype: String) -> PyResult<()> {
let request =
_Request::ReIsq(parse_isq(&dtype).map_err(|e| PyValueError::new_err(e.to_string()))?);
self.runner.get_sender().blocking_send(request).unwrap();
self.runner.get_sender()?.blocking_send(request).unwrap();
Ok(())
}

/// Send a request to make the specified adapters the active adapters for the model.
fn activate_adapters(&self, adapter_names: Vec<String>) {
let request = _Request::ActivateAdapters(adapter_names);
self.runner.get_sender().blocking_send(request).unwrap();
self.runner
.get_sender()
.unwrap()
.blocking_send(request)
.unwrap();
}
}

Expand Down
2 changes: 1 addition & 1 deletion mistralrs-server/src/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ pub async fn chatcompletions(
return ChatCompletionResponder::InternalError(e.into());
}
};
let sender = state.get_sender();
let sender = state.get_sender().unwrap();

if let Err(e) = sender.send(request).await {
let e = anyhow::Error::msg(e.to_string());
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-server/src/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ pub async fn completions(
);
}
let request = parse_request(oairequest, state.clone(), tx);
let sender = state.get_sender();
let sender = state.get_sender().unwrap();

if let Err(e) = sender.send(request).await {
let e = anyhow::Error::msg(e.to_string());
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-server/src/interactive_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ static CTRLC_HANDLER: Lazy<Mutex<&'static (dyn Fn() + Sync)>> =
Lazy::new(|| Mutex::new(&exit_handler));

pub async fn interactive_mode(mistralrs: Arc<MistralRs>) {
let sender = mistralrs.get_sender();
let sender = mistralrs.get_sender().unwrap();
let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();

let sampling_params = SamplingParams {
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async fn activate_adapters(
let repr = format!("Adapter activation: {:?}", request.adapter_names);
MistralRs::maybe_log_request(state.clone(), repr.clone());
let request = Request::ActivateAdapters(request.adapter_names);
state.get_sender().send(request).await.unwrap();
state.get_sender().unwrap().send(request).await.unwrap();
repr
}

Expand All @@ -188,7 +188,7 @@ async fn re_isq(
let repr = format!("Re ISQ: {:?}", request.ggml_type);
MistralRs::maybe_log_request(state.clone(), repr.clone());
let request = Request::ReIsq(parse_isq(&request.ggml_type)?);
state.get_sender().send(request).await.unwrap();
state.get_sender().unwrap().send(request).await.unwrap();
Ok(repr)
}

Expand Down
2 changes: 1 addition & 1 deletion mistralrs/examples/gguf_locally/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ fn main() -> anyhow::Result<()> {
suffix: None,
adapters: None,
});
mistralrs.get_sender().blocking_send(request)?;
mistralrs.get_sender()?.blocking_send(request)?;

let response = rx.blocking_recv().unwrap();
match response {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs/examples/grammar/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {
suffix: None,
adapters: None,
});
mistralrs.get_sender().blocking_send(request)?;
mistralrs.get_sender()?.blocking_send(request)?;

let response = rx.blocking_recv().unwrap();
match response {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs/examples/isq/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {
suffix: None,
adapters: None,
});
mistralrs.get_sender().blocking_send(request)?;
mistralrs.get_sender()?.blocking_send(request)?;

let response = rx.blocking_recv().unwrap();
match response {
Expand Down
4 changes: 2 additions & 2 deletions mistralrs/examples/lora/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ fn main() -> anyhow::Result<()> {

// Example: Make adapter_3 the active adapter
mistralrs
.get_sender()
.get_sender()?
.blocking_send(Request::ActivateAdapters(vec!["adapter_3".to_string()]))?;
mistralrs.get_sender().blocking_send(request)?;
mistralrs.get_sender()?.blocking_send(request)?;

let response = rx.blocking_recv().unwrap();
match response {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs/examples/lora_activation/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn main() -> anyhow::Result<()> {
adapters: Some(vec!["adapter_2".to_string()]),
});

mistralrs.get_sender().blocking_send(request)?;
mistralrs.get_sender()?.blocking_send(request)?;

let response = rx.blocking_recv().unwrap();
match response {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs/examples/phi3v/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn main() -> anyhow::Result<()> {
suffix: None,
adapters: None,
});
mistralrs.get_sender().blocking_send(request)?;
mistralrs.get_sender()?.blocking_send(request)?;

let response = rx.blocking_recv().unwrap();
match response {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs/examples/quantized/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn main() -> anyhow::Result<()> {
suffix: None,
adapters: None,
});
mistralrs.get_sender().blocking_send(request)?;
mistralrs.get_sender()?.blocking_send(request)?;

let response = rx.blocking_recv().unwrap();
match response {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs/examples/simple/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {
suffix: None,
adapters: None,
});
mistralrs.get_sender().blocking_send(request)?;
mistralrs.get_sender()?.blocking_send(request)?;

let response = rx.blocking_recv().unwrap();
match response {
Expand Down
Loading

0 comments on commit 1ae5bfe

Please sign in to comment.