diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 59281e0e6..50341850f 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -89,6 +89,7 @@ pub use sampler::{ CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob, }; pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig}; +pub use sequence::StopReason; use serde::Serialize; use tokio::runtime::Runtime; use toml_selector::{TomlLoaderArgs, TomlSelector}; diff --git a/mistralrs-core/src/pipeline/sampling.rs b/mistralrs-core/src/pipeline/sampling.rs index cab5c6f95..1cb63f637 100644 --- a/mistralrs-core/src/pipeline/sampling.rs +++ b/mistralrs-core/src/pipeline/sampling.rs @@ -49,7 +49,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( role: "assistant".to_string(), }, index: seq.get_response_index(), - finish_reason: is_done.map(|x| x.to_string()), + finish_reason: is_done, logprobs: if seq.return_logprobs() { Some(crate::ResponseLogprob { token: delta, @@ -66,7 +66,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( crate::CompletionChunkChoice { text: delta.clone(), index: seq.get_response_index(), - finish_reason: is_done.map(|x| x.to_string()), + finish_reason: is_done, logprobs: if seq.return_logprobs() { Some(crate::ResponseLogprob { token: delta, @@ -162,6 +162,9 @@ pub(crate) async fn finish_or_add_toks_to_seq( crate::sequence::StopReason::GeneratedImage => { candle_core::bail!("Stop reason was `GeneratedImage`.") } + crate::sequence::StopReason::Error => { + candle_core::bail!("Stop reason was `Error`.") + } }; if seq.get_mut_group().is_chat { @@ -175,7 +178,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( tool_calls = calls; } let choice = crate::Choice { - finish_reason: reason.to_string(), + finish_reason: reason, index: seq.get_response_index(), message: crate::ResponseMessage { content: text_new, @@ -187,7 +190,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( seq.add_choice_to_group(choice); } else { let choice = 
crate::CompletionChoice { - finish_reason: reason.to_string(), + finish_reason: reason, index: seq.get_response_index(), text, logprobs: None, diff --git a/mistralrs-core/src/response.rs b/mistralrs-core/src/response.rs index 9e3eae441..af4381164 100644 --- a/mistralrs-core/src/response.rs +++ b/mistralrs-core/src/response.rs @@ -7,7 +7,7 @@ use std::{ use pyo3::{pyclass, pymethods}; use serde::Serialize; -use crate::{sampler::TopLogprob, tools::ToolCallResponse}; +use crate::{sampler::TopLogprob, sequence::StopReason, tools::ToolCallResponse}; pub const SYSTEM_FINGERPRINT: &str = "local"; @@ -69,45 +69,33 @@ pub struct Logprobs { generate_repr!(Logprobs); -#[cfg_attr(feature = "pyo3_macros", pyclass)] -#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] /// Chat completion choice. pub struct Choice { - pub finish_reason: String, + pub finish_reason: StopReason, pub index: usize, pub message: ResponseMessage, pub logprobs: Option, } -generate_repr!(Choice); - -#[cfg_attr(feature = "pyo3_macros", pyclass)] -#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] /// Chat completion streaming chunk choice. pub struct ChunkChoice { - pub finish_reason: Option, + pub finish_reason: Option, pub index: usize, pub delta: Delta, pub logprobs: Option, } -generate_repr!(ChunkChoice); - -#[cfg_attr(feature = "pyo3_macros", pyclass)] -#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] /// Chat completion streaming chunk choice. 
pub struct CompletionChunkChoice { pub text: String, pub index: usize, pub logprobs: Option, - pub finish_reason: Option, + pub finish_reason: Option, } -generate_repr!(CompletionChunkChoice); - #[cfg_attr(feature = "pyo3_macros", pyclass)] #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] @@ -126,8 +114,6 @@ pub struct Usage { generate_repr!(Usage); -#[cfg_attr(feature = "pyo3_macros", pyclass)] -#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] /// An OpenAI compatible chat completion response. pub struct ChatCompletionResponse { @@ -140,10 +126,6 @@ pub struct ChatCompletionResponse { pub usage: Usage, } -generate_repr!(ChatCompletionResponse); - -#[cfg_attr(feature = "pyo3_macros", pyclass)] -#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] /// Chat completion streaming request chunk. pub struct ChatCompletionChunkResponse { @@ -155,23 +137,15 @@ pub struct ChatCompletionChunkResponse { pub object: String, } -generate_repr!(ChatCompletionChunkResponse); - -#[cfg_attr(feature = "pyo3_macros", pyclass)] -#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] /// Completion request choice. pub struct CompletionChoice { - pub finish_reason: String, + pub finish_reason: StopReason, pub index: usize, pub text: String, pub logprobs: Option<()>, } -generate_repr!(CompletionChoice); - -#[cfg_attr(feature = "pyo3_macros", pyclass)] -#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] /// An OpenAI compatible completion response. pub struct CompletionResponse { @@ -184,10 +158,6 @@ pub struct CompletionResponse { pub usage: Usage, } -generate_repr!(CompletionResponse); - -#[cfg_attr(feature = "pyo3_macros", pyclass)] -#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] /// Completion request choice. 
pub struct CompletionChunkResponse { @@ -199,8 +169,6 @@ pub struct CompletionChunkResponse { pub object: String, } -generate_repr!(CompletionChunkResponse); - #[cfg_attr(feature = "pyo3_macros", pyclass)] #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Debug, Clone, Serialize)] diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index c962ae90f..0d09a4702 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -1,3 +1,4 @@ +use serde::Serialize; use std::{ fmt::Display, sync::{Arc, RwLock}, @@ -27,7 +28,7 @@ use crate::{ use candle_core::Tensor; use regex_automata::util::primitives::StateID; -#[derive(Clone, Copy, PartialEq, Debug)] +#[derive(Clone, Copy, PartialEq, Debug, Serialize)] pub enum StopReason { Eos, StopTok(u32), @@ -39,6 +40,7 @@ pub enum StopReason { }, Canceled, GeneratedImage, + Error, } impl Display for StopReason { @@ -49,6 +51,7 @@ impl Display for StopReason { StopReason::StopTok(_) | StopReason::StopString { .. } => write!(f, "stop"), StopReason::Canceled => write!(f, "canceled"), StopReason::GeneratedImage => write!(f, "generated-image"), + StopReason::Error => write!(f, "error"), } } } diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index 6b202470e..50f94bea6 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -108,7 +108,7 @@ macro_rules! handle_pipeline_forward_error { if seq.get_mut_group().is_chat { let choice = Choice { - finish_reason: "error".to_string(), + finish_reason: StopReason::Error, index: seq.get_response_index(), message: ResponseMessage { content: Some(res), @@ -120,7 +120,7 @@ macro_rules!
handle_pipeline_forward_error { seq.add_choice_to_group(choice); } else { let choice = CompletionChoice { - finish_reason: "error".to_string(), + finish_reason: StopReason::Error, index: seq.get_response_index(), text: res, logprobs: None, diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 0eae3a3c6..61804f624 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -5,6 +5,7 @@ use anymoe::{AnyMoeConfig, AnyMoeExpertType}; use either::Either; use indexmap::IndexMap; use requests::{ChatCompletionRequest, CompletionRequest, ToolChoice}; +use response::{ChatCompletionResponse, Choice, CompletionChoice, CompletionResponse}; use std::{ cell::RefCell, collections::HashMap, @@ -18,20 +19,21 @@ use util::{PyApiErr, PyApiResult}; use candle_core::{Device, Result}; use mistralrs_core::{ - initialize_logging, paged_attn_supported, parse_isq_value, AnyMoeLoader, - ChatCompletionResponse, CompletionResponse, Constraint, DefaultSchedulerMethod, - DeviceLayerMapMetadata, DeviceMapMetadata, DiffusionGenerationParams, DiffusionLoaderBuilder, - DiffusionSpecificConfig, DrySamplingParams, GGMLLoaderBuilder, GGMLSpecificConfig, - GGUFLoaderBuilder, GGUFSpecificConfig, ImageGenerationResponse, ImageGenerationResponseFormat, - Loader, MemoryGpuConfig, MistralRs, MistralRsBuilder, NormalLoaderBuilder, NormalRequest, - NormalSpecificConfig, PagedAttentionConfig, Request as _Request, RequestMessage, Response, - ResponseOk, SamplingParams, SchedulerConfig, SpeculativeConfig, SpeculativeLoader, StopTokens, - TokenSource, Tool, Topology, VisionLoaderBuilder, VisionSpecificConfig, + initialize_logging, paged_attn_supported, parse_isq_value, AnyMoeLoader, Constraint, + DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, DiffusionGenerationParams, + DiffusionLoaderBuilder, DiffusionSpecificConfig, DrySamplingParams, GGMLLoaderBuilder, + GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, ImageGenerationResponse, + 
ImageGenerationResponseFormat, Loader, MemoryGpuConfig, MistralRs, MistralRsBuilder, + NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig, + Request as _Request, RequestMessage, Response, ResponseOk, SamplingParams, SchedulerConfig, + SpeculativeConfig, SpeculativeLoader, StopTokens, TokenSource, Tool, Topology, + VisionLoaderBuilder, VisionSpecificConfig, }; use pyo3::prelude::*; use std::fs::File; mod anymoe; mod requests; +mod response; mod stream; mod util; mod which; @@ -877,7 +879,24 @@ impl Runner { Response::ValidationError(e) | Response::InternalError(e) => { Err(PyApiErr::from(e.to_string())) } - Response::Done(response) => Ok(Either::Left(response)), + Response::Done(response) => Ok(Either::Left(ChatCompletionResponse { + id: response.id, + created: response.created, + model: response.model, + system_fingerprint: response.system_fingerprint, + object: response.object, + usage: response.usage, + choices: response + .choices + .into_iter() + .map(|choice| Choice { + finish_reason: choice.finish_reason.to_string(), + index: choice.index, + message: choice.message, + logprobs: choice.logprobs, + }) + .collect(), + })), Response::ModelError(msg, _) => Err(PyApiErr::from(msg.to_string())), Response::Chunk(_) => unreachable!(), Response::CompletionDone(_) => unreachable!(), @@ -996,7 +1015,24 @@ impl Runner { Response::ValidationError(e) | Response::InternalError(e) => { Err(PyApiErr::from(e.to_string())) } - Response::CompletionDone(response) => Ok(response), + Response::CompletionDone(response) => Ok(CompletionResponse { + id: response.id, + created: response.created, + model: response.model, + system_fingerprint: response.system_fingerprint, + object: response.object, + usage: response.usage, + choices: response + .choices + .into_iter() + .map(|choice| CompletionChoice { + finish_reason: choice.finish_reason.to_string(), + index: choice.index, + text: choice.text, + logprobs: choice.logprobs, + }) + .collect(), + }), 
Response::CompletionModelError(msg, _) => Err(PyApiErr::from(msg.to_string())), Response::Chunk(_) => unreachable!(), Response::Done(_) => unreachable!(), @@ -1093,13 +1129,13 @@ fn mistralrs(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/mistralrs-pyo3/src/response.rs b/mistralrs-pyo3/src/response.rs new file mode 100644 index 000000000..bf1858056 --- /dev/null +++ b/mistralrs-pyo3/src/response.rs @@ -0,0 +1,93 @@ +//! This is just a hack to make `finish_reason` a string! Normally it is an enum + +use mistralrs_core::{Delta, Logprobs, ResponseLogprob, ResponseMessage, Usage}; + +#[pyo3::pyclass] +#[pyo3(get_all)] +#[derive(Clone, Debug)] +pub struct Choice { + pub finish_reason: String, + pub index: usize, + pub message: ResponseMessage, + pub logprobs: Option, +} + +#[pyo3::pyclass] +#[pyo3(get_all)] +#[derive(Clone, Debug)] +pub struct ChatCompletionResponse { + pub id: String, + pub choices: Vec, + pub created: u64, + pub model: String, + pub system_fingerprint: String, + pub object: String, + pub usage: Usage, +} + +#[pyo3::pyclass] +#[pyo3(get_all)] +#[derive(Clone, Debug)] +pub struct CompletionChoice { + pub finish_reason: String, + pub index: usize, + pub text: String, + pub logprobs: Option<()>, +} + +#[pyo3::pyclass] +#[pyo3(get_all)] +#[derive(Clone, Debug)] +pub struct CompletionResponse { + pub id: String, + pub choices: Vec, + pub created: u64, + pub model: String, + pub system_fingerprint: String, + pub object: String, + pub usage: Usage, +} + +#[pyo3::pyclass] +#[pyo3(get_all)] +#[derive(Clone, Debug)] +pub struct CompletionChunkChoice { + pub text: String, + pub 
index: usize, + pub logprobs: Option<ResponseLogprob>, + pub finish_reason: Option<String>, +} + +#[pyo3::pyclass] +#[pyo3(get_all)] +#[derive(Clone, Debug)] +pub struct CompletionChunkResponse { + pub id: String, + pub choices: Vec<CompletionChunkChoice>, + pub created: u128, + pub model: String, + pub system_fingerprint: String, + pub object: String, +} + +#[pyo3::pyclass] +#[pyo3(get_all)] +#[derive(Clone, Debug)] +pub struct ChunkChoice { + pub finish_reason: Option<String>, + pub index: usize, + pub delta: Delta, + pub logprobs: Option<ResponseLogprob>, +} + +#[pyo3::pyclass] +#[pyo3(get_all)] +#[derive(Clone, Debug)] +pub struct ChatCompletionChunkResponse { + pub id: String, + pub choices: Vec<ChunkChoice>, + pub created: u128, + pub model: String, + pub system_fingerprint: String, + pub object: String, +} diff --git a/mistralrs-pyo3/src/stream.rs b/mistralrs-pyo3/src/stream.rs index 79da6f402..17efb3b9a 100644 --- a/mistralrs-pyo3/src/stream.rs +++ b/mistralrs-pyo3/src/stream.rs @@ -1,6 +1,7 @@ use tokio::sync::mpsc::Receiver; -use mistralrs_core::{ChatCompletionChunkResponse, Response}; +use crate::response::{ChatCompletionChunkResponse, ChunkChoice}; +use mistralrs_core::Response; use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyRef, PyRefMut, PyResult}; #[pyclass] @@ -33,7 +34,23 @@ impl ChatCompletionStreamer { if response.choices.iter().all(|x| x.finish_reason.is_some()) { this.is_done = true; } - Some(Ok(response)) + Some(Ok(ChatCompletionChunkResponse { + id: response.id, + created: response.created, + model: response.model, + system_fingerprint: response.system_fingerprint, + object: response.object, + choices: response + .choices + .into_iter() + .map(|choice| ChunkChoice { + finish_reason: choice.finish_reason.map(|r| r.to_string()), + index: choice.index, + delta: choice.delta, + logprobs: choice.logprobs, + }) + .collect::<Vec<ChunkChoice>>(), + })) } Response::Done(_) => unreachable!(), Response::CompletionDone(_) => unreachable!(), diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs index
3575f3fea..6c4ecb316 100644 --- a/mistralrs-server/src/interactive_mode.rs +++ b/mistralrs-server/src/interactive_mode.rs @@ -3,7 +3,7 @@ use indexmap::IndexMap; use mistralrs_core::{ Constraint, DiffusionGenerationParams, DrySamplingParams, ImageGenerationResponseFormat, MessageContent, MistralRs, ModelCategory, NormalRequest, Request, RequestMessage, Response, - ResponseOk, SamplingParams, TERMINATE_ALL_NEXT_STEP, + ResponseOk, SamplingParams, StopReason, TERMINATE_ALL_NEXT_STEP, }; use once_cell::sync::Lazy; use std::{ @@ -190,7 +190,10 @@ async fn text_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool) { toks += 3usize; // NOTE: we send toks every 3. io::stdout().flush().unwrap(); if choice.finish_reason.is_some() { - if matches!(choice.finish_reason.as_ref().unwrap().as_str(), "length") { + if matches!( + choice.finish_reason.as_ref().unwrap(), + StopReason::Length(_) | StopReason::ModelLength(_) + ) { print!("..."); } break; @@ -372,7 +375,10 @@ async fn vision_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool) { toks += 3usize; // NOTE: we send toks every 3. io::stdout().flush().unwrap(); if choice.finish_reason.is_some() { - if matches!( + choice.finish_reason.as_ref().unwrap(), + StopReason::Length(_) | StopReason::ModelLength(_) + ) { print!("..."); } break;