Return stop reason enum in responses #802

Open · wants to merge 3 commits into base: master · changes from all commits
1 change: 1 addition & 0 deletions mistralrs-core/src/lib.rs
@@ -89,6 +89,7 @@ pub use sampler::{
CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
pub use sequence::StopReason;
use serde::Serialize;
use tokio::runtime::Runtime;
use toml_selector::{TomlLoaderArgs, TomlSelector};
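The re-export above makes the typed stop reason part of the crate's public API. A minimal sketch of how a downstream crate might consume it, matching only on variants visible in this PR and letting everything else fall through to the existing `Display` impl (the helper itself is hypothetical, not part of this PR):

```rust
use mistralrs_core::StopReason;

/// Turn a typed stop reason into a human-readable note. Only variants
/// visible in this diff are matched explicitly; the rest fall through
/// to the wildcard arm, which reuses the existing Display impl.
fn describe(reason: StopReason) -> String {
    match reason {
        StopReason::Eos => "model emitted its end-of-sequence token".to_string(),
        StopReason::StopTok(tok) => format!("configured stop token {tok} was generated"),
        StopReason::Canceled => "the request was canceled".to_string(),
        StopReason::GeneratedImage => "generation produced an image".to_string(),
        StopReason::Error => "the pipeline reported an error".to_string(),
        other => format!("stopped: {other}"),
    }
}

fn main() {
    println!("{}", describe(StopReason::Eos));
}
```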
11 changes: 7 additions & 4 deletions mistralrs-core/src/pipeline/sampling.rs
@@ -49,7 +49,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
role: "assistant".to_string(),
},
index: seq.get_response_index(),
finish_reason: is_done.map(|x| x.to_string()),
finish_reason: is_done,
logprobs: if seq.return_logprobs() {
Some(crate::ResponseLogprob {
token: delta,
@@ -66,7 +66,7 @@
crate::CompletionChunkChoice {
text: delta.clone(),
index: seq.get_response_index(),
finish_reason: is_done.map(|x| x.to_string()),
finish_reason: is_done,
logprobs: if seq.return_logprobs() {
Some(crate::ResponseLogprob {
token: delta,
@@ -162,6 +162,9 @@ pub(crate) async fn finish_or_add_toks_to_seq(
crate::sequence::StopReason::GeneratedImage => {
candle_core::bail!("Stop reason was `GeneratedImage`.")
}
crate::sequence::StopReason::Error => {
candle_core::bail!("Stop reason was `Error`.")
}
};

if seq.get_mut_group().is_chat {
@@ -175,7 +178,7 @@
tool_calls = calls;
}
let choice = crate::Choice {
finish_reason: reason.to_string(),
finish_reason: reason,
index: seq.get_response_index(),
message: crate::ResponseMessage {
content: text_new,
@@ -187,7 +190,7 @@
seq.add_choice_to_group(choice);
} else {
let choice = crate::CompletionChoice {
finish_reason: reason.to_string(),
finish_reason: reason,
index: seq.get_response_index(),
text,
logprobs: None,
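With chunk choices now carrying `Option<StopReason>` instead of `Option<String>`, stream consumers can branch on the enum rather than comparing strings. A hedged sketch of a consumer-side check (hypothetical helper, not part of this PR):

```rust
use mistralrs_core::StopReason;

// Hypothetical helper: decide whether a streaming response has finished.
// `None` means the sequence is still generating; any `Some(_)` is terminal,
// and the Error variant added by this PR can now be singled out directly.
fn stream_finished(finish_reason: Option<StopReason>) -> bool {
    match finish_reason {
        None => false,
        Some(StopReason::Error) => {
            eprintln!("stream ended due to a pipeline error");
            true
        }
        Some(_) => true,
    }
}

fn main() {
    assert!(!stream_finished(None));
    assert!(stream_finished(Some(StopReason::Eos)));
}
```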
42 changes: 5 additions & 37 deletions mistralrs-core/src/response.rs
@@ -7,7 +7,7 @@ use std::{
use pyo3::{pyclass, pymethods};
use serde::Serialize;

use crate::{sampler::TopLogprob, tools::ToolCallResponse};
use crate::{sampler::TopLogprob, sequence::StopReason, tools::ToolCallResponse};

pub const SYSTEM_FINGERPRINT: &str = "local";

@@ -69,45 +69,33 @@ pub struct Logprobs {

generate_repr!(Logprobs);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion choice.
pub struct Choice {
pub finish_reason: String,
pub finish_reason: StopReason,
pub index: usize,
pub message: ResponseMessage,
pub logprobs: Option<Logprobs>,
}

generate_repr!(Choice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming chunk choice.
pub struct ChunkChoice {
pub finish_reason: Option<String>,
pub finish_reason: Option<StopReason>,
pub index: usize,
pub delta: Delta,
pub logprobs: Option<ResponseLogprob>,
}

generate_repr!(ChunkChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming chunk choice.
pub struct CompletionChunkChoice {
pub text: String,
pub index: usize,
pub logprobs: Option<ResponseLogprob>,
pub finish_reason: Option<String>,
pub finish_reason: Option<StopReason>,
}

generate_repr!(CompletionChunkChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
@@ -126,8 +114,6 @@ pub struct Usage {

generate_repr!(Usage);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible chat completion response.
pub struct ChatCompletionResponse {
@@ -140,10 +126,6 @@ pub struct ChatCompletionResponse {
pub usage: Usage,
}

generate_repr!(ChatCompletionResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Chat completion streaming request chunk.
pub struct ChatCompletionChunkResponse {
@@ -155,23 +137,15 @@ pub struct ChatCompletionChunkResponse {
pub object: String,
}

generate_repr!(ChatCompletionChunkResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion request choice.
pub struct CompletionChoice {
pub finish_reason: String,
pub finish_reason: StopReason,
pub index: usize,
pub text: String,
pub logprobs: Option<()>,
}

generate_repr!(CompletionChoice);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// An OpenAI compatible completion response.
pub struct CompletionResponse {
@@ -184,10 +158,6 @@ pub struct CompletionResponse {
pub usage: Usage,
}

generate_repr!(CompletionResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
/// Completion request choice.
pub struct CompletionChunkResponse {
@@ -199,8 +169,6 @@ pub struct CompletionChunkResponse {
pub object: String,
}

generate_repr!(CompletionChunkResponse);

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
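One consequence of storing `StopReason` directly in these serializable structs: with the plain `#[derive(Serialize)]` added in `sequence.rs` (no `#[serde(...)]` attributes are visible in this diff), serde uses its default externally tagged encoding, so the JSON differs from the OpenAI-style strings the old `String` fields carried. A sketch of the difference with a trimmed stand-in enum, assuming no serde attributes elsewhere:

```rust
use serde::Serialize;

// Trimmed stand-in mirroring the plain derive added in sequence.rs.
#[derive(Clone, Copy, Debug, Serialize)]
enum StopReason {
    Eos,
    StopTok(u32),
}

fn main() -> serde_json::Result<()> {
    // Unit variants serialize as their names under the default encoding...
    assert_eq!(serde_json::to_string(&StopReason::Eos)?, "\"Eos\"");
    // ...and newtype variants as externally tagged objects, whereas the old
    // String fields carried Display output such as "stop".
    assert_eq!(serde_json::to_string(&StopReason::StopTok(42))?, "{\"StopTok\":42}");
    Ok(())
}
```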
5 changes: 4 additions & 1 deletion mistralrs-core/src/sequence.rs
@@ -1,3 +1,4 @@
use serde::Serialize;
use std::{
fmt::Display,
sync::{Arc, RwLock},
@@ -27,7 +28,7 @@ use crate::{
use candle_core::Tensor;
use regex_automata::util::primitives::StateID;

#[derive(Clone, Copy, PartialEq, Debug)]
#[derive(Clone, Copy, PartialEq, Debug, Serialize)]
pub enum StopReason {
Eos,
StopTok(u32),
@@ -39,6 +40,7 @@ pub enum StopReason {
},
Canceled,
GeneratedImage,
Error,
}

impl Display for StopReason {
@@ -49,6 +51,7 @@ impl Display for StopReason {
StopReason::StopTok(_) | StopReason::StopString { .. } => write!(f, "stop"),
StopReason::Canceled => write!(f, "canceled"),
StopReason::GeneratedImage => write!(f, "generated-image"),
StopReason::Error => write!(f, "stop"),
}
}
}
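The visible `Display` arms fold the typed variants onto OpenAI-compatible strings, and the new `Error` variant renders as "stop" as well rather than introducing a new string. A short check of the mapping, using only arms shown in this diff:

```rust
use mistralrs_core::StopReason;

fn main() {
    // Arms visible in this hunk:
    assert_eq!(StopReason::StopTok(2).to_string(), "stop");
    assert_eq!(StopReason::Canceled.to_string(), "canceled");
    assert_eq!(StopReason::GeneratedImage.to_string(), "generated-image");
    // The new variant reuses "stop" when rendered for OpenAI-style output.
    assert_eq!(StopReason::Error.to_string(), "stop");
}
```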
4 changes: 2 additions & 2 deletions mistralrs-core/src/utils/mod.rs
@@ -108,7 +108,7 @@ macro_rules! handle_pipeline_forward_error {

if seq.get_mut_group().is_chat {
let choice = Choice {
finish_reason: "error".to_string(),
finish_reason: StopReason::Error,
index: seq.get_response_index(),
message: ResponseMessage {
content: Some(res),
@@ -120,7 +120,7 @@
seq.add_choice_to_group(choice);
} else {
let choice = CompletionChoice {
finish_reason: "error".to_string(),
finish_reason: StopReason::Error,
index: seq.get_response_index(),
text: res,
logprobs: None,
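With the macro's error branch now emitting `StopReason::Error` instead of the literal string "error", callers can detect pipeline failures with a direct match. Note that when rendered through `Display`, this variant now reads "stop" rather than the old "error" string. A hypothetical helper illustrating the check (not part of this PR):

```rust
use mistralrs_core::StopReason;

// Hypothetical helper: the typed variant makes the error path explicit,
// where the old code required comparing finish_reason == "error".
fn is_pipeline_error(finish_reason: StopReason) -> bool {
    matches!(finish_reason, StopReason::Error)
}

fn main() {
    assert!(is_pipeline_error(StopReason::Error));
    assert!(!is_pipeline_error(StopReason::Eos));
}
```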
70 changes: 53 additions & 17 deletions mistralrs-pyo3/src/lib.rs
@@ -5,6 +5,7 @@ use anymoe::{AnyMoeConfig, AnyMoeExpertType};
use either::Either;
use indexmap::IndexMap;
use requests::{ChatCompletionRequest, CompletionRequest, ToolChoice};
use response::{ChatCompletionResponse, Choice, CompletionChoice, CompletionResponse};
use std::{
cell::RefCell,
collections::HashMap,
@@ -18,20 +19,21 @@ use util::{PyApiErr, PyApiResult};

use candle_core::{Device, Result};
use mistralrs_core::{
initialize_logging, paged_attn_supported, parse_isq_value, AnyMoeLoader,
ChatCompletionResponse, CompletionResponse, Constraint, DefaultSchedulerMethod,
DeviceLayerMapMetadata, DeviceMapMetadata, DiffusionGenerationParams, DiffusionLoaderBuilder,
DiffusionSpecificConfig, DrySamplingParams, GGMLLoaderBuilder, GGMLSpecificConfig,
GGUFLoaderBuilder, GGUFSpecificConfig, ImageGenerationResponse, ImageGenerationResponseFormat,
Loader, MemoryGpuConfig, MistralRs, MistralRsBuilder, NormalLoaderBuilder, NormalRequest,
NormalSpecificConfig, PagedAttentionConfig, Request as _Request, RequestMessage, Response,
ResponseOk, SamplingParams, SchedulerConfig, SpeculativeConfig, SpeculativeLoader, StopTokens,
TokenSource, Tool, Topology, VisionLoaderBuilder, VisionSpecificConfig,
initialize_logging, paged_attn_supported, parse_isq_value, AnyMoeLoader, Constraint,
DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, DiffusionGenerationParams,
DiffusionLoaderBuilder, DiffusionSpecificConfig, DrySamplingParams, GGMLLoaderBuilder,
GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, ImageGenerationResponse,
ImageGenerationResponseFormat, Loader, MemoryGpuConfig, MistralRs, MistralRsBuilder,
NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
Request as _Request, RequestMessage, Response, ResponseOk, SamplingParams, SchedulerConfig,
SpeculativeConfig, SpeculativeLoader, StopTokens, TokenSource, Tool, Topology,
VisionLoaderBuilder, VisionSpecificConfig,
};
use pyo3::prelude::*;
use std::fs::File;
mod anymoe;
mod requests;
mod response;
mod stream;
mod util;
mod which;
@@ -877,7 +879,24 @@ impl Runner {
Response::ValidationError(e) | Response::InternalError(e) => {
Err(PyApiErr::from(e.to_string()))
}
Response::Done(response) => Ok(Either::Left(response)),
Response::Done(response) => Ok(Either::Left(ChatCompletionResponse {
id: response.id,
created: response.created,
model: response.model,
system_fingerprint: response.system_fingerprint,
object: response.object,
usage: response.usage,
choices: response
.choices
.into_iter()
.map(|choice| Choice {
finish_reason: choice.finish_reason.to_string(),
index: choice.index,
message: choice.message,
logprobs: choice.logprobs,
})
.collect(),
})),
Response::ModelError(msg, _) => Err(PyApiErr::from(msg.to_string())),
Response::Chunk(_) => unreachable!(),
Response::CompletionDone(_) => unreachable!(),
@@ -996,7 +1015,24 @@ impl Runner {
Response::ValidationError(e) | Response::InternalError(e) => {
Err(PyApiErr::from(e.to_string()))
}
Response::CompletionDone(response) => Ok(response),
Response::CompletionDone(response) => Ok(CompletionResponse {
id: response.id,
created: response.created,
model: response.model,
system_fingerprint: response.system_fingerprint,
object: response.object,
usage: response.usage,
choices: response
.choices
.into_iter()
.map(|choice| CompletionChoice {
finish_reason: choice.finish_reason.to_string(),
index: choice.index,
text: choice.text,
logprobs: choice.logprobs,
})
.collect(),
}),
Response::CompletionModelError(msg, _) => Err(PyApiErr::from(msg.to_string())),
Response::Chunk(_) => unreachable!(),
Response::Done(_) => unreachable!(),
@@ -1093,13 +1129,13 @@ fn mistralrs(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<mistralrs_core::Delta>()?;
m.add_class::<mistralrs_core::ResponseLogprob>()?;
m.add_class::<mistralrs_core::Logprobs>()?;
m.add_class::<mistralrs_core::Choice>()?;
m.add_class::<mistralrs_core::ChunkChoice>()?;
m.add_class::<response::Choice>()?;
m.add_class::<response::ChunkChoice>()?;
m.add_class::<mistralrs_core::Usage>()?;
m.add_class::<mistralrs_core::ChatCompletionResponse>()?;
m.add_class::<mistralrs_core::ChatCompletionChunkResponse>()?;
m.add_class::<mistralrs_core::CompletionChoice>()?;
m.add_class::<mistralrs_core::CompletionResponse>()?;
m.add_class::<response::ChatCompletionResponse>()?;
m.add_class::<response::ChatCompletionChunkResponse>()?;
m.add_class::<response::CompletionChoice>()?;
m.add_class::<response::CompletionResponse>()?;
m.add_class::<mistralrs_core::TopLogprob>()?;
m.add_class::<mistralrs_core::ModelDType>()?;
m.add_class::<mistralrs_core::ImageGenerationResponseFormat>()?;
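The Python bindings keep their existing string-typed API by converting at the boundary: the core response now carries the enum, and the new local `response` types re-render it with `to_string()`. A trimmed sketch of that pattern (stand-in structs; the real types live in `mistralrs_core` and the new `mistralrs-pyo3/src/response.rs`):

```rust
use mistralrs_core::StopReason;

// Stand-ins for mistralrs_core::Choice and the pyo3 crate's response::Choice,
// trimmed to the fields relevant to this conversion.
struct CoreChoice {
    finish_reason: StopReason,
    index: usize,
}

struct PyChoice {
    finish_reason: String,
    index: usize,
}

fn to_py(choice: CoreChoice) -> PyChoice {
    PyChoice {
        // Display restores the OpenAI-style string (e.g. "stop"), so the
        // Python-visible schema is unchanged by this PR.
        finish_reason: choice.finish_reason.to_string(),
        index: choice.index,
    }
}

fn main() {
    let py = to_py(CoreChoice { finish_reason: StopReason::Error, index: 0 });
    assert_eq!(py.finish_reason, "stop"); // Error renders as "stop" per sequence.rs
}
```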