diff --git a/mistralrs-bench/Cargo.toml b/mistralrs-bench/Cargo.toml
index 65dedb124..73b351d4d 100644
--- a/mistralrs-bench/Cargo.toml
+++ b/mistralrs-bench/Cargo.toml
@@ -21,6 +21,7 @@ mistralrs-core = { version = "0.1.1", path = "../mistralrs-core" }
 tracing.workspace = true
 tracing-subscriber.workspace = true
 either.workspace = true
+tokio.workspace = true
 cli-table = "0.4.7"
 
 [features]
diff --git a/mistralrs-bench/src/main.rs b/mistralrs-bench/src/main.rs
index cd83bbb88..73eef8756 100644
--- a/mistralrs-bench/src/main.rs
+++ b/mistralrs-bench/src/main.rs
@@ -6,8 +6,9 @@ use mistralrs_core::{
     ModelSelected, Request, RequestMessage, Response, SamplingParams, SchedulerMethod,
     TokenSource, Usage,
 };
+use std::fmt::Display;
 use std::sync::Arc;
-use std::{fmt::Display, sync::mpsc::channel};
+use tokio::sync::mpsc::channel;
 use tracing::info;
 use tracing::level_filters::LevelFilter;
 use tracing_subscriber::EnvFilter;
@@ -65,7 +66,7 @@ fn run_bench(
         n_choices: 1,
     };
     let sender = mistralrs.get_sender();
-    let (tx, rx) = channel();
+    let (tx, mut rx) = channel(10_000);
 
     let req = Request {
         id: mistralrs.next_request_id(),
@@ -82,11 +83,13 @@
     for _ in 0..repetitions {
         for _ in 0..concurrency {
-            sender.send(req.clone()).expect("Expected receiver.");
+            sender
+                .blocking_send(req.clone())
+                .expect("Expected receiver.");
         }
         for _ in 0..concurrency {
-            match rx.recv() {
-                Ok(r) => match r {
+            match rx.blocking_recv() {
+                Some(r) => match r {
                     Response::InternalError(e) => {
                         unreachable!("Got an internal error: {e:?}");
                     }
@@ -105,7 +108,7 @@
                         usages.push(res.usage);
                     }
                 },
-                Err(e) => unreachable!("Expected a Done response, got: {:?}", e),
+                None => unreachable!("Expected a Done response, got None"),
             }
         }
     }
diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs
index 25d0ac8cd..4a63159cd 100644
--- a/mistralrs-core/src/engine/mod.rs
+++ b/mistralrs-core/src/engine/mod.rs
@@ -1,11 +1,10 @@
 use std::{
-    cell::RefCell,
     collections::{HashMap, VecDeque},
     iter::zip,
-    rc::Rc,
-    sync::{mpsc::Receiver, Arc, Mutex},
+    sync::{Arc, Mutex},
     time::{Instant, SystemTime, UNIX_EPOCH},
 };
+use tokio::sync::mpsc::Receiver;
 
 use crate::{
     aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx, toktree::TokTrie},
@@ -90,7 +89,7 @@ impl Engine {
         let mut last_completion_ids: Vec<usize> = vec![];
         'lp: loop {
             while let Ok(request) = self.rx.try_recv() {
-                self.add_request(request);
+                self.add_request(request).await;
             }
             let run_start = Instant::now();
             let mut scheduled = self.scheduler.schedule();
@@ -104,26 +103,31 @@
                 let current_completion_ids: Vec<usize> =
                     scheduled.completion.iter().map(|seq| *seq.id()).collect();
                 let logits = {
-                    let mut pipeline = get_mut_arcmutex!(self.pipeline);
-                    // Run the completion seqs
-                    // Run the completion seqs
-                    if !self.no_kv_cache && last_completion_ids != current_completion_ids {
-                        Self::clone_in_cache(&mut *pipeline, &mut scheduled.completion);
-                    }
-                    let logits = pipeline.forward(&scheduled.completion, false);
+                    let logits = {
+                        let mut pipeline = get_mut_arcmutex!(self.pipeline);
+                        // Run the completion seqs
+                        // Run the completion seqs
+                        if !self.no_kv_cache && last_completion_ids != current_completion_ids {
+                            Self::clone_in_cache(&mut *pipeline, &mut scheduled.completion);
+                        }
+                        pipeline.forward(&scheduled.completion, false)
+                    };
                     let logits = handle_pipeline_forward_error!(
                         "completion",
                         logits,
                         &mut scheduled.completion,
-                        pipeline,
+                        self.pipeline,
                         'lp,
                         self.prefix_cacher
                     );
 
-                    if !self.no_kv_cache {
-                        Self::clone_out_cache(&mut *pipeline, &mut scheduled.completion);
-                    } else {
-                        Self::set_none_cache(&mut *pipeline);
+                    {
+                        let mut pipeline = get_mut_arcmutex!(self.pipeline);
+                        if !self.no_kv_cache {
+                            Self::clone_out_cache(&mut *pipeline, &mut scheduled.completion);
+                        } else {
+                            Self::set_none_cache(&mut *pipeline);
+                        }
                     }
                     logits
                 };
@@ -141,7 +145,7 @@
                     "sampling",
                     sampled_result,
                     &mut scheduled.completion,
-                    get_mut_arcmutex!(self.pipeline),
+                    self.pipeline,
                     'lp,
                     self.prefix_cacher
                 );
@@ -154,23 +158,25 @@
                     // Run the prompt seqs
                     Self::set_none_cache(&mut *pipeline);
-                    let logits = pipeline.forward(&scheduled.prompt, true);
-                    let logits = handle_pipeline_forward_error!(
-                        "prompt",
-                        logits,
-                        &mut scheduled.prompt,
-                        pipeline,
-                        'lp,
-                        self.prefix_cacher
-                    );
+                    pipeline.forward(&scheduled.prompt, true)
+                };
+                let logits = handle_pipeline_forward_error!(
+                    "prompt",
+                    logits,
+                    &mut scheduled.prompt,
+                    self.pipeline,
+                    'lp,
+                    self.prefix_cacher
+                );
 
+                {
+                    let mut pipeline = get_mut_arcmutex!(self.pipeline);
                     if !self.no_kv_cache {
                         Self::clone_out_cache(&mut *pipeline, &mut scheduled.prompt);
                     } else {
                         Self::set_none_cache(&mut *pipeline);
                     }
-                    logits
-                };
+                }
 
                 let sampled_result = Self::sample_seqs(
                     self.pipeline.clone(),
@@ -185,7 +191,7 @@
                     "sampling",
                     sampled_result,
                     &mut scheduled.prompt,
-                    get_mut_arcmutex!(self.pipeline),
+                    self.pipeline,
                     'lp,
                     self.prefix_cacher
                 );
@@ -235,8 +241,8 @@
                 && self.scheduler.waiting_len() == 0
             {
                 // If there is nothing to do, sleep until a request comes in
-                if let Ok(request) = self.rx.recv() {
-                    self.add_request(request);
+                if let Some(request) = self.rx.recv().await {
+                    self.add_request(request).await;
                 }
             }
         }
@@ -399,6 +405,7 @@
                 if seq
                     .get_mut_group()
                     .maybe_send_streaming_response(seq, pipeline_name.clone())
+                    .await
                     .is_err()
                 {
                     // If we can't send the response, cancel the sequence
@@ -408,25 +415,29 @@
                     }
                 }
             } else if let Some(reason) = is_done {
-                let mut pipeline = get_mut_arcmutex!(pipeline);
-                Self::finish_seq(&mut *pipeline, seq, reason, prefix_cacher)?;
-                pipeline.reset_non_granular_state();
+                Self::finish_seq(pipeline.clone(), seq, reason, prefix_cacher).await?;
+                get_mut_arcmutex!(pipeline).reset_non_granular_state();
             }
         }
 
         Ok(())
     }
 
-    fn finish_seq(
-        pipeline: &mut dyn Pipeline,
+    async fn finish_seq(
+        pipeline: Arc<Mutex<dyn Pipeline>>,
         seq: &mut Sequence,
         reason: StopReason,
         prefix_cacher: &mut PrefixCacheManager,
     ) -> Result<()> {
         seq.set_state(SequenceState::Done(reason));
 
+        let (tokenizer, pipeline_name) = {
+            let pipeline = get_mut_arcmutex!(pipeline);
+            let pipeline_name = pipeline.name();
+            let tokenizer = pipeline.tokenizer();
+            (tokenizer, pipeline_name)
+        };
         let logprobs = if seq.return_logprobs() {
-            let tokenizer = pipeline.tokenizer().clone();
             let mut logprobs = Vec::new();
             for logprob in seq.logprobs() {
                 let resp_logprob = ResponseLogprob {
@@ -488,31 +499,37 @@
         let group = seq.get_mut_group();
         if group.is_chat {
-            group.maybe_send_done_response(
-                ChatCompletionResponse {
-                    id: seq.id().to_string(),
-                    choices: group.get_choices().to_vec(),
-                    created: seq.creation_time(),
-                    model: pipeline.name(),
-                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
-                    object: "chat.completion".to_string(),
-                    usage: group.get_usage(),
-                },
-                seq.responder(),
-            );
+            group
+                .maybe_send_done_response(
+                    ChatCompletionResponse {
+                        id: seq.id().to_string(),
+                        choices: group.get_choices().to_vec(),
+                        created: seq.creation_time(),
+                        model: pipeline_name,
+                        system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
+                        object: "chat.completion".to_string(),
+                        usage: group.get_usage(),
+                    },
+                    seq.responder(),
+                )
+                .await
+                .map_err(candle_core::Error::msg)?;
         } else {
-            group.maybe_send_completion_done_response(
-                CompletionResponse {
-                    id: seq.id().to_string(),
-                    choices: group.get_completion_choices().to_vec(),
-                    created: seq.creation_time(),
-                    model: pipeline.name(),
-                    system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
-                    object: "text_completion".to_string(),
-                    usage: group.get_usage(),
-                },
-                seq.responder(),
-            );
+            group
+                .maybe_send_completion_done_response(
+                    CompletionResponse {
+                        id: seq.id().to_string(),
+                        choices: group.get_completion_choices().to_vec(),
+                        created: seq.creation_time(),
+                        model: pipeline_name,
+                        system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
+                        object: "text_completion".to_string(),
+                        usage: group.get_usage(),
+                    },
+                    seq.responder(),
+                )
+                .await
+                .map_err(candle_core::Error::msg)?;
         }
 
         Ok(())
@@ -667,7 +684,7 @@ impl Engine {
         }
     }
 
-    fn add_request(&mut self, request: Request) {
+    async fn add_request(&mut self, request: Request) {
         let is_chat = matches!(request.messages, RequestMessage::Chat(_));
         let echo_prompt = matches!(
             request.messages,
@@ -690,17 +707,15 @@ impl Engine {
                 .response
                 .send(Response::ValidationError(
                     "Received messages for a model which does not have a chat template. Either use a different model or pass a single string as the prompt".into(),
-                )).expect("Expected receiver.");
+                )).await.expect("Expected receiver.");
             return;
         }
 
         let mut force_tokens = None;
         let formatted_prompt = match request.messages {
             RequestMessage::Chat(messages) => {
-                handle_seq_error!(
-                    get_mut_arcmutex!(self.pipeline).apply_chat_template(messages, true),
-                    request.response
-                )
+                let template = get_mut_arcmutex!(self.pipeline).apply_chat_template(messages, true);
+                handle_seq_error!(template, request.response)
             }
             RequestMessage::Completion { text, .. } => text,
             RequestMessage::CompletionTokens(it) => {
@@ -718,15 +733,16 @@
                 .send(Response::ValidationError(
                     "Received an empty prompt.".into(),
                 ))
+                .await
                 .expect("Expected receiver.");
             return;
         }
 
         let mut prompt = match force_tokens {
             Some(tks) => tks,
-            None => handle_seq_error!(
-                get_mut_arcmutex!(self.pipeline).tokenize_prompt(&formatted_prompt),
-                request.response
-            ),
+            None => {
+                let prompt = get_mut_arcmutex!(self.pipeline).tokenize_prompt(&formatted_prompt);
+                handle_seq_error!(prompt, request.response)
+            }
         };
 
         if prompt.len() > get_mut_arcmutex!(self.pipeline).get_max_seq_len() {
@@ -735,7 +751,7 @@
                 .response
                 .send(Response::ValidationError(
                     format!("Prompt sequence length is greater than {}, perhaps consider using `truncate_sequence`?", get_mut_arcmutex!(self.pipeline).get_max_seq_len()).into(),
-                )).expect("Expected receiver.");
+                )).await.expect("Expected receiver.");
             return;
         } else {
             let prompt_len = prompt.len();
@@ -770,8 +786,10 @@
         let (stop_toks, stop_strings) = match request.sampling_params.stop_toks {
             None => (vec![], vec![]),
             Some(StopTokens::Ids(ref i)) => {
-                let pipeline = get_mut_arcmutex!(self.pipeline);
-                let tok_trie = pipeline.tok_trie();
+                let tok_trie = {
+                    let pipeline = get_mut_arcmutex!(self.pipeline);
+                    pipeline.tok_trie()
+                };
                 for id in i {
                     // We can't use ` ` (space) as a stop token because other tokens like ` moon` start with a space.
                     if tok_trie.has_extensions(tok_trie.token(*id)) {
@@ -780,7 +798,7 @@
                             .send(Response::ValidationError(
                                 format!("Stop token {:?} is also a prefix of other tokens and cannot be used as a stop token.", tok_trie.token_str(*id)).into(),
                             ))
-                            .expect("Expected receiver.");
+                            .await.expect("Expected receiver.");
                         return;
                     }
                 }
@@ -791,9 +809,12 @@
                 let mut stop_toks = Vec::new();
                 let mut stop_strings: Vec<String> = Vec::new();
 
-                let pipeline = get_mut_arcmutex!(self.pipeline);
-                let tok_trie = pipeline.tok_trie();
-                let tokenizer = pipeline.tokenizer();
+                let (tok_trie, tokenizer) = {
+                    let pipeline = get_mut_arcmutex!(self.pipeline);
+                    let tok_trie = pipeline.tok_trie();
+                    let tokenizer = pipeline.tokenizer();
+                    (tok_trie, tokenizer)
+                };
 
                 for stop_txt in s {
                     let encoded = tokenizer.encode(stop_txt.to_string(), false);
@@ -815,7 +836,8 @@
                 (stop_toks, stop_strings)
             }
         };
-        let group = Rc::new(RefCell::new(SequenceGroup::new(
+
+        let group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
            request.sampling_params.n_choices,
            request.is_streaming,
            is_chat,
@@ -833,6 +855,7 @@
                 .send(Response::ValidationError(
                     format!("Failed creation of logits bias. {}", err).into(),
                 ))
+                .await
                 .expect("Expected receiver.");
             return;
         }
@@ -857,6 +880,7 @@
                 .send(Response::ValidationError(
                     format!("Invalid grammar. {}", err).into(),
                 ))
+                .await
                 .expect("Expected receiver.");
             return;
         }
@@ -868,6 +892,7 @@
                 .send(Response::ValidationError(
                     "Number of choices must be greater than 0.".into(),
                 ))
+                .await
                 .expect("Expected receiver.");
             return;
         }
diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs
index 347a5e058..61177832d 100644
--- a/mistralrs-core/src/lib.rs
+++ b/mistralrs-core/src/lib.rs
@@ -5,13 +5,11 @@ use std::{
     error::Error,
     fs::OpenOptions,
     io::Write,
-    sync::{
-        mpsc::{channel, Sender},
-        Arc, Mutex,
-    },
+    sync::{Arc, Mutex},
     thread,
     time::{SystemTime, UNIX_EPOCH},
 };
+use tokio::sync::mpsc::{channel, Sender};
 
 use candle_core::quantized::GgmlDType;
 use engine::Engine;
@@ -147,8 +145,8 @@
         let prefix_cache_n = prefix_cache_n.unwrap_or(16);
         let disable_eos_stop = disable_eos_stop.unwrap_or(false);
 
-        let (tx, rx) = channel();
-        let (isq_tx, isq_rx) = channel();
+        let (tx, rx) = channel(10_000);
+        let (isq_tx, isq_rx) = channel(10_000);
 
         let this = Arc::new(Self {
             sender: tx,
@@ -189,7 +187,9 @@
     /// Send a request to re-ISQ the model. If the model was loaded as GGUF or GGML
    /// then nothing will happen.
    pub fn send_re_isq(&self, dtype: GgmlDType) {
-        self.sender_isq.send(dtype).expect("Engine is not present.")
+        self.sender_isq
+            .blocking_send(dtype)
+            .expect("Engine is not present.")
     }
 
     pub fn get_id(&self) -> String {
diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs
index 7b88a7713..38b857d55 100644
--- a/mistralrs-core/src/request.rs
+++ b/mistralrs-core/src/request.rs
@@ -1,7 +1,8 @@
 use indexmap::IndexMap;
 
 use crate::{response::Response, sampler::SamplingParams};
-use std::{fmt::Debug, sync::mpsc::Sender};
+use std::fmt::Debug;
+use tokio::sync::mpsc::Sender;
 
 #[derive(Clone)]
 /// Control the constraint with Regex or Yacc.
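The heart of this migration is swapping `std::sync::mpsc`'s unbounded channels for tokio's bounded `mpsc` everywhere a `Request` or `Response` crosses a thread or task boundary. A minimal, self-contained sketch of the resulting usage pattern (not code from the patch; `Msg` stands in for the crate's `Request`/`Response` types, and a tokio runtime with the `rt-multi-thread` and `sync` features is assumed):

```rust
use tokio::sync::mpsc::channel;

#[derive(Debug)]
struct Msg(u32);

fn main() {
    // Bounded channel: senders suspend (or block, for `blocking_send`) once
    // 10_000 messages are in flight, giving backpressure that std's
    // unbounded `mpsc::channel()` never applied.
    let (tx, mut rx) = channel::<Msg>(10_000);

    let rt = tokio::runtime::Runtime::new().unwrap();

    // Async side: an engine-style loop awaits requests; `recv` yields
    // `None` once every sender has been dropped.
    let engine = rt.spawn(async move {
        while let Some(msg) = rx.recv().await {
            println!("engine got {msg:?}");
        }
    });

    // Sync side: code running outside the runtime (as the bench binary and
    // the pyo3 bindings do in this patch) uses the blocking variants in
    // place of `.await`. `blocking_send` panics if called on a runtime
    // worker thread, so it runs on a plain OS thread here.
    std::thread::spawn(move || {
        tx.blocking_send(Msg(1)).expect("Expected receiver.");
        // Dropping `tx` closes the channel and ends the engine loop.
    })
    .join()
    .unwrap();

    rt.block_on(engine).unwrap();
}
```

This is also why the hunks above thread `.await` through `send` calls and change `rx.recv()`/`Ok(..)` matches into `rx.recv().await`/`Some(..)`: tokio's `recv` returns `Option` rather than `Result`.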
diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs
index a5d9d5e55..9526229e2 100644
--- a/mistralrs-core/src/sequence.rs
+++ b/mistralrs-core/src/sequence.rs
@@ -1,10 +1,12 @@
 use std::{
-    cell::{Cell, RefCell, RefMut},
-    rc::Rc,
-    sync::mpsc::{SendError, Sender},
+    cell::Cell,
     sync::Arc,
     time::{SystemTime, UNIX_EPOCH},
 };
+use tokio::sync::{
+    mpsc::{error::SendError, Sender},
+    Mutex, MutexGuard,
+};
 
 use crate::{
     aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx},
@@ -94,7 +96,7 @@ pub struct Sequence {
     // GPU things
     pub prompt_tok_per_sec: f32,
     pub prompt_timestamp: Option<u128>,
-    group: Rc<RefCell<SequenceGroup>>,
+    group: Arc<Mutex<SequenceGroup>>,
     state: Cell<SequenceState>,
 }
 impl Sequence {
@@ -111,7 +113,7 @@ impl Sequence {
         max_len: Option<usize>,
         return_logprobs: bool,
         is_xlora: bool,
-        group: Rc<RefCell<SequenceGroup>>,
+        group: Arc<Mutex<SequenceGroup>>,
         response_index: usize,
         creation_time: u64,
         recognizer: SequenceRecognizer,
@@ -399,7 +401,7 @@ impl Sequence {
         self.response_index
     }
 
-    pub fn get_mut_group(&self) -> RefMut<'_, SequenceGroup> {
+    pub fn get_mut_group(&self) -> MutexGuard<'_, SequenceGroup> {
         get_mut_group!(self)
     }
 
@@ -476,19 +478,19 @@ impl SequenceGroup {
         }
     }
 
-    pub fn maybe_send_done_response(
+    pub async fn maybe_send_done_response(
         &self,
         response: ChatCompletionResponse,
         sender: Sender<Response>,
-    ) {
+    ) -> Result<(), SendError<Response>> {
         if self.choices.len() == self.n_choices {
-            sender
-                .send(Response::Done(response))
-                .expect("Expected receiver.");
+            sender.send(Response::Done(response)).await?;
         }
+
+        Ok(())
     }
 
-    pub fn maybe_send_streaming_response(
+    pub async fn maybe_send_streaming_response(
         &mut self,
         seq: &Sequence,
         model: String,
@@ -506,20 +508,20 @@ impl SequenceGroup {
                     model: model.clone(),
                     system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                     object: "chat.completion.chunk".to_string(),
-                }))?;
+                }))
+                .await?;
         }
         Ok(())
     }
 
-    pub fn maybe_send_completion_done_response(
+    pub async fn maybe_send_completion_done_response(
         &self,
         response: CompletionResponse,
         sender: Sender<Response>,
-    ) {
+    ) -> Result<(), Box<SendError<Response>>> {
         if self.completion_choices.len() == self.n_choices {
-            sender
-                .send(Response::CompletionDone(response))
-                .expect("Expected receiver.");
+            sender.send(Response::CompletionDone(response)).await?;
         }
+        Ok(())
     }
 }
diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs
index 272d398de..a8b49f6f3 100644
--- a/mistralrs-core/src/utils/mod.rs
+++ b/mistralrs-core/src/utils/mod.rs
@@ -21,6 +21,7 @@ macro_rules! handle_seq_error {
                 use $crate::response::Response;
                 $response
                     .send(Response::InternalError(e.into()))
+                    .await
                     .expect("Expected receiver.");
                 return;
             }
@@ -37,6 +38,7 @@ macro_rules! handle_seq_error_ok {
                 use $crate::response::Response;
                 $response
                     .send(Response::InternalError(e.into()))
+                    .await
                     .expect("Expected receiver.");
                 return Ok(());
             }
@@ -44,24 +46,6 @@ macro_rules! handle_seq_error_ok {
     };
 }
 
-#[macro_export]
-macro_rules! handle_seq_error_stateaware {
-    ($fallible:expr, $seq:expr) => {
-        match $fallible {
-            Ok(v) => v,
-            Err(e) => {
-                use $crate::response::Response;
-                use $crate::sequence::SequenceState;
-                $seq.responder()
-                    .send(Response::InternalError(e.into()))
-                    .expect("Expected receiver.");
-                $seq.set_state(SequenceState::Error);
-                return;
-            }
-        }
-    };
-}
-
 #[macro_export]
 macro_rules! handle_seq_error_stateaware_ok {
     ($fallible:expr, $seq:expr) => {
@@ -72,6 +56,7 @@ macro_rules! handle_seq_error_stateaware_ok {
                 use $crate::sequence::SequenceState;
                 $seq.responder()
                     .send(Response::InternalError(e.into()))
+                    .await
                     .expect("Expected receiver.");
                 $seq.set_state(SequenceState::Error);
                 return Ok(());
@@ -86,7 +71,12 @@ macro_rules! handle_pipeline_forward_error {
         match $fallible {
             Ok(v) => v,
             Err(e) => {
-                let mut pipeline = $pipeline;
+                let (tokenizer, pipeline_name) = {
+                    let pipeline = get_mut_arcmutex!($pipeline);
+                    let pipeline_name = pipeline.name();
+                    let tokenizer = pipeline.tokenizer();
+                    (tokenizer, pipeline_name)
+                };
                 use $crate::response::Response;
                 use $crate::sequence::SequenceState;
                 use $crate::Engine;
@@ -95,8 +85,7 @@ macro_rules! handle_pipeline_forward_error {
                 error!("{} - Model failed with error: {:?}", $stage, &e);
                 for seq in $seq_slice.iter_mut() {
                     // Step 1: Add all choices to groups
-                    let res = match pipeline
-                        .tokenizer()
+                    let res = match tokenizer
                         .decode(&seq.get_toks()[seq.prompt_tokens()..], false)
                     {
                         Ok(v) => v,
@@ -133,7 +122,7 @@ macro_rules! handle_pipeline_forward_error {
                                 id: seq.id().to_string(),
                                 choices: group.get_choices().to_vec(),
                                 created: seq.creation_time(),
-                                model: pipeline.name(),
+                                model: pipeline_name.clone(),
                                 system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                                 object: "chat.completion".to_string(),
                                 usage: group.get_usage(),
@@ -144,13 +133,14 @@ macro_rules! handle_pipeline_forward_error {
                                     e.to_string(),
                                     partial_completion_response
                                 ))
+                                .await
                                 .unwrap();
                         } else {
                             let partial_completion_response = CompletionResponse {
                                 id: seq.id().to_string(),
                                 choices: group.get_completion_choices().to_vec(),
                                 created: seq.creation_time(),
-                                model: pipeline.name(),
+                                model: pipeline_name.clone(),
                                 system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
                                 object: "text_completion".to_string(),
                                 usage: group.get_usage(),
@@ -161,6 +151,7 @@ macro_rules! handle_pipeline_forward_error {
                                     e.to_string(),
                                     partial_completion_response
                                 ))
+                                .await
                                 .unwrap();
                         }
                     }
@@ -169,7 +160,7 @@ macro_rules! handle_pipeline_forward_error {
                     seq.set_state(SequenceState::Error);
                 }
 
-                Engine::set_none_cache(&mut *pipeline);
+                Engine::set_none_cache(&mut *get_mut_arcmutex!($pipeline));
                 $prefix_cacher.evict_all_to_cpu().unwrap();
 
                 continue $label;
@@ -182,7 +173,7 @@ macro_rules! get_mut_group {
     ($this:expr) => {
         loop {
-            if let Ok(inner) = $this.group.try_borrow_mut() {
+            if let Ok(inner) = $this.group.try_lock() {
                 break inner;
             }
         }
diff --git a/mistralrs-pyo3/Cargo.toml b/mistralrs-pyo3/Cargo.toml
index 3db491ee0..b7373d891 100644
--- a/mistralrs-pyo3/Cargo.toml
+++ b/mistralrs-pyo3/Cargo.toml
@@ -26,6 +26,7 @@ accelerate-src = { workspace = true, optional = true }
 intel-mkl-src = { workspace = true, optional = true }
 either.workspace = true
 futures.workspace = true
+tokio.workspace = true
 tracing-subscriber.workspace = true
 
 [features]
diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs
index 8a4a2cd54..b187734b8 100644
--- a/mistralrs-pyo3/src/lib.rs
+++ b/mistralrs-pyo3/src/lib.rs
@@ -9,9 +9,10 @@ use std::{
     collections::HashMap,
     fmt::Debug,
     str::FromStr,
-    sync::{mpsc::channel, Arc, Mutex},
+    sync::{Arc, Mutex},
 };
 use stream::ChatCompletionStreamer;
+use tokio::sync::mpsc::channel;
 use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 
 use candle_core::Device;
@@ -415,7 +416,7 @@ impl Runner {
         &mut self,
         request: Py<ChatCompletionRequest>,
     ) -> PyResult<Either<ChatCompletionResponse, ChatCompletionStreamer>> {
-        let (tx, rx) = channel();
+        let (tx, mut rx) = channel(10_000);
         Python::with_gil(|py| {
             let request = request.bind(py).borrow();
             let stop_toks = request
@@ -497,12 +498,12 @@ impl Runner {
             MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}"));
             let sender = self.runner.get_sender();
-            sender.send(model_request).unwrap();
+            sender.blocking_send(model_request).unwrap();
 
             if request.stream {
                 Ok(Either::Right(ChatCompletionStreamer::from_rx(rx)))
             } else {
-                let response = rx.recv().unwrap();
+                let response = rx.blocking_recv().unwrap();
 
                 match response {
                     Response::ValidationError(e) | Response::InternalError(e) => {
@@ -523,7 +524,7 @@ impl Runner {
         &mut self,
         request: Py<CompletionRequest>,
     ) -> PyResult<CompletionResponse> {
-        let (tx, rx) = channel();
+        let (tx, mut rx) = channel(10_000);
         Python::with_gil(|py| {
             let request = request.bind(py).borrow();
             let stop_toks = request
@@ -585,8 +586,8 @@ impl Runner {
             MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}"));
             let sender = self.runner.get_sender();
-            sender.send(model_request).unwrap();
-            let response = rx.recv().unwrap();
+            sender.blocking_send(model_request).unwrap();
+            let response = rx.blocking_recv().unwrap();
 
             match response {
                 Response::ValidationError(e) | Response::InternalError(e) => {
diff --git a/mistralrs-pyo3/src/stream.rs b/mistralrs-pyo3/src/stream.rs
index 62d29a7c6..eee128341 100644
--- a/mistralrs-pyo3/src/stream.rs
+++ b/mistralrs-pyo3/src/stream.rs
@@ -1,4 +1,4 @@
-use std::sync::mpsc::Receiver;
+use tokio::sync::mpsc::Receiver;
 
 use mistralrs_core::{ChatCompletionChunkResponse, Response};
 use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyRef, PyRefMut, PyResult};
@@ -24,8 +24,8 @@ impl ChatCompletionStreamer {
         if this.is_done {
             return None;
         }
-        match this.rx.recv() {
-            Ok(resp) => match resp {
+        match this.rx.blocking_recv() {
+            Some(resp) => match resp {
                 Response::ModelError(msg, _) => Some(Err(PyValueError::new_err(msg.to_string()))),
                 Response::ValidationError(e) => Some(Err(PyValueError::new_err(e.to_string()))),
                 Response::InternalError(e) => Some(Err(PyValueError::new_err(e.to_string()))),
@@ -39,7 +39,9 @@ impl ChatCompletionStreamer {
                 Response::CompletionDone(_) => unreachable!(),
                 Response::CompletionModelError(_, _) => unreachable!(),
             },
-            Err(e) => Some(Err(PyValueError::new_err(e.to_string()))),
+            None => Some(Err(PyValueError::new_err(
+                "Received none in ChatCompletionStreamer".to_string(),
+            ))),
         }
     }
 }
diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs
index 0a41154f6..364f2a522 100644
--- a/mistralrs-server/src/chat_completion.rs
+++ b/mistralrs-server/src/chat_completion.rs
@@ -2,13 +2,11 @@ use std::{
     env,
     error::Error,
     pin::Pin,
-    sync::{
-        mpsc::{channel, Receiver, Sender},
-        Arc,
-    },
+    sync::Arc,
     task::{Context, Poll},
     time::Duration,
 };
+use tokio::sync::mpsc::{channel, Receiver, Sender};
 
 use crate::openai::{ChatCompletionRequest, Grammar, StopTokens};
 use anyhow::Result;
@@ -217,11 +215,16 @@ pub async fn chatcompletions(
     State(state): State<Arc<MistralRs>>,
     Json(oairequest): Json<ChatCompletionRequest>,
 ) -> ChatCompletionResponder {
-    let (tx, rx) = channel();
+    let (tx, mut rx) = channel(10_000);
     let request = parse_request(oairequest, state.clone(), tx);
     let is_streaming = request.is_streaming;
     let sender = state.get_sender();
-    sender.send(request).unwrap();
+
+    if let Err(e) = sender.send(request).await {
+        let e = anyhow::Error::msg(e.to_string());
+        MistralRs::maybe_log_error(state, &*e);
+        return ChatCompletionResponder::InternalError(e.into());
+    }
 
     if is_streaming {
         let streamer = Streamer {
@@ -242,7 +245,14 @@
             ),
         )
     } else {
-        let response = rx.recv().unwrap();
+        let response = match rx.recv().await {
+            Some(response) => response,
+            None => {
+                let e = anyhow::Error::msg("No response received from the model.");
+                MistralRs::maybe_log_error(state, &*e);
+                return ChatCompletionResponder::InternalError(e.into());
+            }
+        };
 
         match response {
             Response::InternalError(e) => {
diff --git a/mistralrs-server/src/completions.rs b/mistralrs-server/src/completions.rs
index 38e37e6f4..cfc01c9de 100644
--- a/mistralrs-server/src/completions.rs
+++ b/mistralrs-server/src/completions.rs
@@ -1,10 +1,5 @@
-use std::{
-    error::Error,
-    sync::{
-        mpsc::{channel, Sender},
-        Arc,
-    },
-};
+use std::{error::Error, sync::Arc};
+use tokio::sync::mpsc::{channel, Sender};
 
 use crate::openai::{CompletionRequest, Grammar, StopTokens};
 use axum::{
@@ -151,7 +146,7 @@ pub async fn completions(
     State(state): State<Arc<MistralRs>>,
     Json(oairequest): Json<CompletionRequest>,
 ) -> CompletionResponder {
-    let (tx, rx) = channel();
+    let (tx, mut rx) = channel(10_000);
     let request = parse_request(oairequest, state.clone(), tx);
     let is_streaming = request.is_streaming;
     let sender = state.get_sender();
@@ -168,9 +163,20 @@ pub async fn completions(
         );
     }
 
-    sender.send(request).unwrap();
+    if let Err(e) = sender.send(request).await {
+        let e = anyhow::Error::msg(e.to_string());
+        MistralRs::maybe_log_error(state, &*e);
+        return CompletionResponder::InternalError(e.into());
+    }
 
-    let response = rx.recv().unwrap();
+    let response = match rx.recv().await {
+        Some(response) => response,
+        None => {
+            let e = anyhow::Error::msg("No response received from the model.");
+            MistralRs::maybe_log_error(state, &*e);
+            return CompletionResponder::InternalError(e.into());
+        }
+    };
 
     match response {
         Response::InternalError(e) => {
diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs
index 58f10168b..0784677aa 100644
--- a/mistralrs-server/src/interactive_mode.rs
+++ b/mistralrs-server/src/interactive_mode.rs
@@ -2,11 +2,12 @@ use indexmap::IndexMap;
 use mistralrs_core::{Constraint, MistralRs, Request, RequestMessage, Response, SamplingParams};
 use std::{
     io::{self, Write},
-    sync::{mpsc::channel, Arc},
+    sync::Arc,
 };
+use tokio::sync::mpsc::channel;
 use tracing::{error, info};
 
-pub fn interactive_mode(mistralrs: Arc<MistralRs>) {
+pub async fn interactive_mode(mistralrs: Arc<MistralRs>) {
     let sender = mistralrs.get_sender();
     let mut messages = Vec::new();
 
@@ -35,7 +36,7 @@
         user_message.insert("content".to_string(), prompt);
         messages.push(user_message);
 
-        let (tx, rx) = channel();
+        let (tx, mut rx) = channel(10_000);
         let req = Request {
             id: mistralrs.next_request_id(),
             messages: RequestMessage::Chat(messages.clone()),
@@ -46,41 +47,39 @@
             constraint: Constraint::None,
             suffix: None,
         };
-        sender.send(req).unwrap();
+        sender.send(req).await.unwrap();
 
         let mut assistant_output = String::new();
-        loop {
-            let resp = rx.try_recv();
-            if let Ok(resp) = resp {
-                match resp {
-                    Response::Chunk(chunk) => {
-                        let choice = &chunk.choices[0];
-                        assistant_output.push_str(&choice.delta.content);
-                        print!("{}", choice.delta.content);
-                        io::stdout().flush().unwrap();
-                        if choice.finish_reason.is_some() {
-                            if matches!(choice.finish_reason.as_ref().unwrap().as_str(), "length") {
-                                print!("...");
-                            }
-                            break;
+
+        while let Some(resp) = rx.recv().await {
+            match resp {
+                Response::Chunk(chunk) => {
+                    let choice = &chunk.choices[0];
+                    assistant_output.push_str(&choice.delta.content);
+                    print!("{}", choice.delta.content);
+                    io::stdout().flush().unwrap();
+                    if choice.finish_reason.is_some() {
+                        if matches!(choice.finish_reason.as_ref().unwrap().as_str(), "length") {
+                            print!("...");
                         }
+                        break;
                     }
-                    Response::InternalError(e) => {
-                        error!("Got an internal error: {e:?}");
-                        break 'outer;
-                    }
-                    Response::ModelError(e, resp) => {
-                        error!("Got a model error: {e:?}, response: {resp:?}");
-                        break 'outer;
-                    }
-                    Response::ValidationError(e) => {
-                        error!("Got a validation error: {e:?}");
-                        break 'outer;
-                    }
-                    Response::Done(_) => unreachable!(),
-                    Response::CompletionDone(_) => unreachable!(),
-                    Response::CompletionModelError(_, _) => unreachable!(),
                 }
+                Response::InternalError(e) => {
+                    error!("Got an internal error: {e:?}");
+                    break 'outer;
+                }
+                Response::ModelError(e, resp) => {
+                    error!("Got a model error: {e:?}, response: {resp:?}");
+                    break 'outer;
+                }
+                Response::ValidationError(e) => {
+                    error!("Got a validation error: {e:?}");
+                    break 'outer;
+                }
+                Response::Done(_) => unreachable!(),
+                Response::CompletionDone(_) => unreachable!(),
+                Response::CompletionModelError(_, _) => unreachable!(),
             }
         }
         let mut assistant_message = IndexMap::new();
diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs
index 3848483d2..225c91710 100644
--- a/mistralrs-server/src/main.rs
+++ b/mistralrs-server/src/main.rs
@@ -255,7 +255,7 @@ async fn main() -> Result<()> {
         .build();
 
     if args.interactive_mode {
-        interactive_mode(mistralrs);
+        interactive_mode(mistralrs).await;
         return Ok(());
     }
 
diff --git a/mistralrs/Cargo.toml b/mistralrs/Cargo.toml
index 48c058e7c..219bd471e 100644
--- a/mistralrs/Cargo.toml
+++ b/mistralrs/Cargo.toml
@@ -14,6 +14,7 @@ homepage.workspace = true
 [dependencies]
 mistralrs-core = { version = "0.1.1", path = "../mistralrs-core" }
 anyhow.workspace = true
+tokio.workspace = true
 candle-core.workspace = true
 
 [features]
diff --git a/mistralrs/examples/grammar/main.rs b/mistralrs/examples/grammar/main.rs
index e95924210..bc7a3997b 100644
--- a/mistralrs/examples/grammar/main.rs
+++ b/mistralrs/examples/grammar/main.rs
@@ -1,4 +1,5 @@
-use std::sync::{mpsc::channel, Arc};
+use std::sync::Arc;
+use tokio::sync::mpsc::channel;
 
 use candle_core::Device;
 use mistralrs::{
@@ -36,7 +37,7 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
 
 fn main() -> anyhow::Result<()> {
     let mistralrs = setup()?;
-    let (tx, rx) = channel();
+    let (tx, mut rx) = channel(10_000);
     let request = Request {
         messages: RequestMessage::Completion {
             text: "I like to code in the following language: ".to_string(),
@@ -51,9 +52,9 @@ fn main() -> anyhow::Result<()> {
         constraint: Constraint::Regex("(- [^\n]*\n)+(- [^\n]*)(\n\n)?".to_string()), // Bullet list regex
         suffix: None,
     };
-    mistralrs.get_sender().send(request)?;
+    mistralrs.get_sender().blocking_send(request)?;
 
-    let response = rx.recv().unwrap();
+    let response = rx.blocking_recv().unwrap();
     match response {
         Response::CompletionDone(c) => println!("Text: {}", c.choices[0].text),
         _ => unreachable!(),
diff --git a/mistralrs/examples/isq/main.rs b/mistralrs/examples/isq/main.rs
index 6b99f2143..40a864a93 100644
--- a/mistralrs/examples/isq/main.rs
+++ b/mistralrs/examples/isq/main.rs
@@ -1,4 +1,5 @@
-use std::sync::{mpsc::channel, Arc};
+use std::sync::Arc;
+use tokio::sync::mpsc::channel;
 
 use candle_core::{quantized::GgmlDType, Device};
 use mistralrs::{
@@ -36,7 +37,7 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
 
 fn main() -> anyhow::Result<()> {
     let mistralrs = setup()?;
-    let (tx, rx) = channel();
+    let (tx, mut rx) = channel(10_000);
     let request = Request {
         messages: RequestMessage::Completion {
             text: "Hello! My name is ".to_string(),
@@ -51,9 +52,9 @@ fn main() -> anyhow::Result<()> {
         constraint: Constraint::None,
         suffix: None,
     };
-    mistralrs.get_sender().send(request)?;
+    mistralrs.get_sender().blocking_send(request)?;
 
-    let response = rx.recv().unwrap();
+    let response = rx.blocking_recv().unwrap();
     match response {
         Response::CompletionDone(c) => println!("Text: {}", c.choices[0].text),
         _ => unreachable!(),
diff --git a/mistralrs/examples/quantized/main.rs b/mistralrs/examples/quantized/main.rs
index 955fa30d4..804238fad 100644
--- a/mistralrs/examples/quantized/main.rs
+++ b/mistralrs/examples/quantized/main.rs
@@ -1,4 +1,5 @@
-use std::sync::{mpsc::channel, Arc};
+use std::sync::Arc;
+use tokio::sync::mpsc::channel;
 
 use candle_core::Device;
 use mistralrs::{
@@ -35,7 +36,7 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
 
 fn main() -> anyhow::Result<()> {
     let mistralrs = setup()?;
-    let (tx, rx) = channel();
+    let (tx, mut rx) = channel(10_000);
     let request = Request {
         messages: RequestMessage::Completion {
             text: "Hello! My name is ".to_string(),
@@ -50,9 +51,9 @@ fn main() -> anyhow::Result<()> {
         constraint: Constraint::None,
         suffix: None,
     };
-    mistralrs.get_sender().send(request)?;
+    mistralrs.get_sender().blocking_send(request)?;
 
-    let response = rx.recv().unwrap();
+    let response = rx.blocking_recv().unwrap();
     match response {
         Response::CompletionDone(c) => println!("Text: {}", c.choices[0].text),
         _ => unreachable!(),
diff --git a/mistralrs/examples/simple/main.rs b/mistralrs/examples/simple/main.rs
index 787cc4fc6..80c778f1a 100644
--- a/mistralrs/examples/simple/main.rs
+++ b/mistralrs/examples/simple/main.rs
@@ -1,4 +1,5 @@
-use std::sync::{mpsc::channel, Arc};
+use std::sync::Arc;
+use tokio::sync::mpsc::channel;
 
 use candle_core::Device;
 use mistralrs::{
@@ -36,7 +37,7 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
 
 fn main() -> anyhow::Result<()> {
     let mistralrs = setup()?;
-    let (tx, rx) = channel();
+    let (tx, mut rx) = channel(10_000);
     let request = Request {
         messages: RequestMessage::Completion {
             text: "Hello! My name is ".to_string(),
@@ -51,9 +52,9 @@ fn main() -> anyhow::Result<()> {
         constraint: Constraint::None,
         suffix: None,
     };
-    mistralrs.get_sender().send(request)?;
+    mistralrs.get_sender().blocking_send(request)?;
 
-    let response = rx.recv().unwrap();
+    let response = rx.blocking_recv().unwrap();
     match response {
         Response::CompletionDone(c) => println!("Text: {}", c.choices[0].text),
         _ => unreachable!(),
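A second pattern that recurs through the engine hunks above is narrowing the scope of `get_mut_arcmutex!` guards: a `std::sync::MutexGuard` is not `Send`, so it must be dropped before any `.await`, which is why the patch clones the pipeline name and tokenizer out of the lock in tight blocks instead of holding the guard across sends. A sketch of the idiom under the same assumptions (the `Pipeline` struct here is hypothetical, not the crate's trait):

```rust
use std::sync::{Arc, Mutex};

struct Pipeline {
    name: String,
}

// Copy what is needed out of the lock in a tight block so the guard is
// dropped before the `.await`; holding a std MutexGuard across an await
// point would make this future non-`Send` and park a blocking lock on a
// suspended task.
async fn respond(pipeline: Arc<Mutex<Pipeline>>, tx: tokio::sync::mpsc::Sender<String>) {
    let name = {
        let guard = pipeline.lock().unwrap();
        guard.name.clone()
    }; // guard dropped here, before the await below

    tx.send(name).await.expect("Expected receiver.");
}

#[tokio::main]
async fn main() {
    let pipeline = Arc::new(Mutex::new(Pipeline { name: "demo".into() }));
    let (tx, mut rx) = tokio::sync::mpsc::channel(8);
    respond(pipeline, tx).await;
    assert_eq!(rx.recv().await.as_deref(), Some("demo"));
}
```

The same reasoning explains why `SequenceGroup` moves from `Rc<RefCell<_>>` to `Arc<tokio::sync::Mutex<_>>`: values captured by the now-async engine must be `Send`, and `Rc`/`RefCell` are not.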