Return stop reason enum
EricLBuehler committed Sep 28, 2024
1 parent 776c116 commit db60ae4
Showing 9 changed files with 189 additions and 65 deletions.
1 change: 1 addition & 0 deletions mistralrs-core/src/lib.rs
@@ -89,6 +89,7 @@ pub use sampler::{
     CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
 };
 pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
+pub use sequence::StopReason;
 use serde::Serialize;
 use tokio::runtime::Runtime;
 use toml_selector::{TomlLoaderArgs, TomlSelector};
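Re-exporting StopReason from the crate root lets downstream users match on the typed finish reason instead of parsing strings. A minimal sketch of a downstream caller (the describe helper is hypothetical, not part of this commit):

use mistralrs_core::StopReason;

// Hypothetical helper: turn a typed stop reason into a log message.
fn describe(reason: StopReason) -> &'static str {
    match reason {
        StopReason::Eos => "hit the end-of-sequence token",
        StopReason::Canceled => "request was canceled",
        StopReason::Error => "pipeline returned an error",
        // StopTok, StopString, and the remaining variants carry payloads;
        // a real caller could inspect them here.
        _ => "other stop condition",
    }
}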
11 changes: 7 additions & 4 deletions mistralrs-core/src/pipeline/sampling.rs
@@ -49,7 +49,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
                 role: "assistant".to_string(),
             },
             index: seq.get_response_index(),
-            finish_reason: is_done.map(|x| x.to_string()),
+            finish_reason: is_done,
             logprobs: if seq.return_logprobs() {
                 Some(crate::ResponseLogprob {
                     token: delta,
@@ -66,7 +66,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
         crate::CompletionChunkChoice {
             text: delta.clone(),
             index: seq.get_response_index(),
-            finish_reason: is_done.map(|x| x.to_string()),
+            finish_reason: is_done,
             logprobs: if seq.return_logprobs() {
                 Some(crate::ResponseLogprob {
                     token: delta,
@@ -162,6 +162,9 @@ pub(crate) async fn finish_or_add_toks_to_seq(
         crate::sequence::StopReason::GeneratedImage => {
             candle_core::bail!("Stop reason was `GeneratedImage`.")
         }
+        crate::sequence::StopReason::Error => {
+            candle_core::bail!("Stop reason was `Error`.")
+        }
     };

     if seq.get_mut_group().is_chat {
@@ -175,7 +178,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
             tool_calls = calls;
         }
         let choice = crate::Choice {
-            finish_reason: reason.to_string(),
+            finish_reason: reason,
             index: seq.get_response_index(),
             message: crate::ResponseMessage {
                 content: text_new,
@@ -187,7 +190,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
         seq.add_choice_to_group(choice);
     } else {
         let choice = crate::CompletionChoice {
-            finish_reason: reason.to_string(),
+            finish_reason: reason,
            index: seq.get_response_index(),
            text,
            logprobs: None,
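With is_done forwarded as-is, a chunk's finish_reason is None while tokens are still streaming and Some(StopReason::...) only on the final chunk. A sketch of consuming that, using a simplified stand-in for crate::ChunkChoice:

use mistralrs_core::StopReason;

// Simplified stand-in for the crate's chunk choice type.
struct Chunk {
    delta: String,
    finish_reason: Option<StopReason>,
}

fn on_chunk(chunk: &Chunk) {
    print!("{}", chunk.delta);
    match chunk.finish_reason {
        // Mid-stream: keep reading.
        None => {}
        // Final chunk: Display still yields the OpenAI-style string ("stop", ...).
        Some(reason) => println!("\n[finished: {reason}]"),
    }
}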
42 changes: 5 additions & 37 deletions mistralrs-core/src/response.rs
@@ -7,7 +7,7 @@ use std::{
 use pyo3::{pyclass, pymethods};
 use serde::Serialize;

-use crate::{sampler::TopLogprob, tools::ToolCallResponse};
+use crate::{sampler::TopLogprob, sequence::StopReason, tools::ToolCallResponse};

 pub const SYSTEM_FINGERPRINT: &str = "local";

@@ -69,45 +69,33 @@ pub struct Logprobs {

 generate_repr!(Logprobs);

-#[cfg_attr(feature = "pyo3_macros", pyclass)]
-#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
 /// Chat completion choice.
 pub struct Choice {
-    pub finish_reason: String,
+    pub finish_reason: StopReason,
     pub index: usize,
     pub message: ResponseMessage,
     pub logprobs: Option<Logprobs>,
 }

-generate_repr!(Choice);
-
-#[cfg_attr(feature = "pyo3_macros", pyclass)]
-#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
 /// Chat completion streaming chunk choice.
 pub struct ChunkChoice {
-    pub finish_reason: Option<String>,
+    pub finish_reason: Option<StopReason>,
     pub index: usize,
     pub delta: Delta,
     pub logprobs: Option<ResponseLogprob>,
 }

-generate_repr!(ChunkChoice);
-
-#[cfg_attr(feature = "pyo3_macros", pyclass)]
-#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
 /// Chat completion streaming chunk choice.
 pub struct CompletionChunkChoice {
     pub text: String,
     pub index: usize,
     pub logprobs: Option<ResponseLogprob>,
-    pub finish_reason: Option<String>,
+    pub finish_reason: Option<StopReason>,
 }

-generate_repr!(CompletionChunkChoice);
-
 #[cfg_attr(feature = "pyo3_macros", pyclass)]
 #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
@@ -126,8 +114,6 @@ pub struct Usage {

 generate_repr!(Usage);

-#[cfg_attr(feature = "pyo3_macros", pyclass)]
-#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
 /// An OpenAI compatible chat completion response.
 pub struct ChatCompletionResponse {
@@ -140,10 +126,6 @@ pub struct ChatCompletionResponse {
     pub usage: Usage,
 }

-generate_repr!(ChatCompletionResponse);
-
-#[cfg_attr(feature = "pyo3_macros", pyclass)]
-#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
 /// Chat completion streaming request chunk.
 pub struct ChatCompletionChunkResponse {
@@ -155,23 +137,15 @@ pub struct ChatCompletionChunkResponse {
     pub object: String,
 }

-generate_repr!(ChatCompletionChunkResponse);
-
-#[cfg_attr(feature = "pyo3_macros", pyclass)]
-#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
 /// Completion request choice.
 pub struct CompletionChoice {
-    pub finish_reason: String,
+    pub finish_reason: StopReason,
     pub index: usize,
     pub text: String,
     pub logprobs: Option<()>,
 }

-generate_repr!(CompletionChoice);
-
-#[cfg_attr(feature = "pyo3_macros", pyclass)]
-#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
 /// An OpenAI compatible completion response.
 pub struct CompletionResponse {
@@ -184,10 +158,6 @@ pub struct CompletionResponse {
     pub usage: Usage,
 }

-generate_repr!(CompletionResponse);
-
-#[cfg_attr(feature = "pyo3_macros", pyclass)]
-#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
 /// Completion request choice.
 pub struct CompletionChunkResponse {
@@ -199,8 +169,6 @@ pub struct CompletionChunkResponse {
     pub object: String,
 }

-generate_repr!(CompletionChunkResponse);
-
 #[cfg_attr(feature = "pyo3_macros", pyclass)]
 #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
 #[derive(Debug, Clone, Serialize)]
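Since finish_reason is now the Serialize-derived enum rather than a pre-rendered string, serde's default externally tagged representation decides the JSON: unit variants become bare strings, payload variants become one-key maps. A self-contained sketch using a trimmed mirror of the enum (illustration only, not the crate's actual types):

use serde::Serialize;

// Trimmed mirror of mistralrs_core's StopReason, for illustration.
#[derive(Clone, Copy, Debug, Serialize)]
enum StopReason {
    Eos,
    StopTok(u32),
    Canceled,
    Error,
}

#[derive(Serialize)]
struct Choice {
    finish_reason: StopReason,
    index: usize,
}

fn main() {
    let done = Choice { finish_reason: StopReason::Eos, index: 0 };
    let cut = Choice { finish_reason: StopReason::StopTok(2), index: 1 };
    // Prints {"finish_reason":"Eos","index":0}
    println!("{}", serde_json::to_string(&done).unwrap());
    // Prints {"finish_reason":{"StopTok":2},"index":1}
    println!("{}", serde_json::to_string(&cut).unwrap());
}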
5 changes: 4 additions & 1 deletion mistralrs-core/src/sequence.rs
@@ -1,3 +1,4 @@
+use serde::Serialize;
 use std::{
     fmt::Display,
     sync::{Arc, RwLock},
@@ -27,7 +28,7 @@ use crate::{
 use candle_core::Tensor;
 use regex_automata::util::primitives::StateID;

-#[derive(Clone, Copy, PartialEq, Debug)]
+#[derive(Clone, Copy, PartialEq, Debug, Serialize)]
 pub enum StopReason {
     Eos,
     StopTok(u32),
@@ -39,6 +40,7 @@ pub enum StopReason {
     },
     Canceled,
     GeneratedImage,
+    Error,
 }

 impl Display for StopReason {
@@ -49,6 +51,7 @@ impl Display for StopReason {
             StopReason::StopTok(_) | StopReason::StopString { .. } => write!(f, "stop"),
             StopReason::Canceled => write!(f, "canceled"),
             StopReason::GeneratedImage => write!(f, "generated-image"),
+            StopReason::Error => write!(f, "stop"),
         }
     }
 }
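The Display impl keeps the OpenAI-compatible strings, and the new Error variant deliberately renders as "stop", so anything that still calls to_string() (for example the pyo3 layer below) is unaffected. A quick check against the arms visible above, assuming the re-export from lib.rs:

use mistralrs_core::StopReason;

fn main() {
    assert_eq!(StopReason::Canceled.to_string(), "canceled");
    assert_eq!(StopReason::GeneratedImage.to_string(), "generated-image");
    assert_eq!(StopReason::StopTok(2).to_string(), "stop");
    // The new variant maps to "stop" as well, per the arm added above.
    assert_eq!(StopReason::Error.to_string(), "stop");
}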
4 changes: 2 additions & 2 deletions mistralrs-core/src/utils/mod.rs
@@ -108,7 +108,7 @@ macro_rules! handle_pipeline_forward_error {

                 if seq.get_mut_group().is_chat {
                     let choice = Choice {
-                        finish_reason: "error".to_string(),
+                        finish_reason: StopReason::Error,
                         index: seq.get_response_index(),
                         message: ResponseMessage {
                             content: Some(res),
@@ -120,7 +120,7 @@ macro_rules! handle_pipeline_forward_error {
                     seq.add_choice_to_group(choice);
                 } else {
                     let choice = CompletionChoice {
-                        finish_reason: "error".to_string(),
+                        finish_reason: StopReason::Error,
                        index: seq.get_response_index(),
                        text: res,
                        logprobs: None,
70 changes: 53 additions & 17 deletions mistralrs-pyo3/src/lib.rs
@@ -5,6 +5,7 @@ use anymoe::{AnyMoeConfig, AnyMoeExpertType};
 use either::Either;
 use indexmap::IndexMap;
 use requests::{ChatCompletionRequest, CompletionRequest, ToolChoice};
+use response::{ChatCompletionResponse, Choice, CompletionChoice, CompletionResponse};
 use std::{
     cell::RefCell,
     collections::HashMap,
@@ -18,20 +19,21 @@ use util::{PyApiErr, PyApiResult};

 use candle_core::{Device, Result};
 use mistralrs_core::{
-    initialize_logging, paged_attn_supported, parse_isq_value, AnyMoeLoader,
-    ChatCompletionResponse, CompletionResponse, Constraint, DefaultSchedulerMethod,
-    DeviceLayerMapMetadata, DeviceMapMetadata, DiffusionGenerationParams, DiffusionLoaderBuilder,
-    DiffusionSpecificConfig, DrySamplingParams, GGMLLoaderBuilder, GGMLSpecificConfig,
-    GGUFLoaderBuilder, GGUFSpecificConfig, ImageGenerationResponse, ImageGenerationResponseFormat,
-    IsqOrganization, Loader, MemoryGpuConfig, MistralRs, MistralRsBuilder, NormalLoaderBuilder,
-    NormalRequest, NormalSpecificConfig, PagedAttentionConfig, Request as _Request, RequestMessage,
-    Response, ResponseOk, SamplingParams, SchedulerConfig, SpeculativeConfig, SpeculativeLoader,
-    StopTokens, TokenSource, Tool, Topology, VisionLoaderBuilder, VisionSpecificConfig,
+    initialize_logging, paged_attn_supported, parse_isq_value, AnyMoeLoader, Constraint,
+    DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, DiffusionGenerationParams,
+    DiffusionLoaderBuilder, DiffusionSpecificConfig, DrySamplingParams, GGMLLoaderBuilder,
+    GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, ImageGenerationResponse,
+    ImageGenerationResponseFormat, IsqOrganization, Loader, MemoryGpuConfig, MistralRs,
+    MistralRsBuilder, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig,
+    PagedAttentionConfig, Request as _Request, RequestMessage, Response, ResponseOk,
+    SamplingParams, SchedulerConfig, SpeculativeConfig, SpeculativeLoader, StopTokens, TokenSource,
+    Tool, Topology, VisionLoaderBuilder, VisionSpecificConfig,
 };
 use pyo3::prelude::*;
 use std::fs::File;
 mod anymoe;
 mod requests;
+mod response;
 mod stream;
 mod util;
 mod which;
@@ -880,7 +882,24 @@ impl Runner {
             Response::ValidationError(e) | Response::InternalError(e) => {
                 Err(PyApiErr::from(e.to_string()))
             }
-            Response::Done(response) => Ok(Either::Left(response)),
+            Response::Done(response) => Ok(Either::Left(ChatCompletionResponse {
+                id: response.id,
+                created: response.created,
+                model: response.model,
+                system_fingerprint: response.system_fingerprint,
+                object: response.object,
+                usage: response.usage,
+                choices: response
+                    .choices
+                    .into_iter()
+                    .map(|choice| Choice {
+                        finish_reason: choice.finish_reason.to_string(),
+                        index: choice.index,
+                        message: choice.message,
+                        logprobs: choice.logprobs,
+                    })
+                    .collect(),
+            })),
             Response::ModelError(msg, _) => Err(PyApiErr::from(msg.to_string())),
             Response::Chunk(_) => unreachable!(),
             Response::CompletionDone(_) => unreachable!(),
@@ -999,7 +1018,24 @@ impl Runner {
             Response::ValidationError(e) | Response::InternalError(e) => {
                 Err(PyApiErr::from(e.to_string()))
             }
-            Response::CompletionDone(response) => Ok(response),
+            Response::CompletionDone(response) => Ok(CompletionResponse {
+                id: response.id,
+                created: response.created,
+                model: response.model,
+                system_fingerprint: response.system_fingerprint,
+                object: response.object,
+                usage: response.usage,
+                choices: response
+                    .choices
+                    .into_iter()
+                    .map(|choice| CompletionChoice {
+                        finish_reason: choice.finish_reason.to_string(),
+                        index: choice.index,
+                        text: choice.text,
+                        logprobs: choice.logprobs,
+                    })
+                    .collect(),
+            }),
             Response::CompletionModelError(msg, _) => Err(PyApiErr::from(msg.to_string())),
             Response::Chunk(_) => unreachable!(),
             Response::Done(_) => unreachable!(),
@@ -1096,13 +1132,13 @@ fn mistralrs(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<mistralrs_core::Delta>()?;
     m.add_class::<mistralrs_core::ResponseLogprob>()?;
     m.add_class::<mistralrs_core::Logprobs>()?;
-    m.add_class::<mistralrs_core::Choice>()?;
-    m.add_class::<mistralrs_core::ChunkChoice>()?;
+    m.add_class::<response::Choice>()?;
+    m.add_class::<response::ChunkChoice>()?;
     m.add_class::<mistralrs_core::Usage>()?;
-    m.add_class::<mistralrs_core::ChatCompletionResponse>()?;
-    m.add_class::<mistralrs_core::ChatCompletionChunkResponse>()?;
-    m.add_class::<mistralrs_core::CompletionChoice>()?;
-    m.add_class::<mistralrs_core::CompletionResponse>()?;
+    m.add_class::<response::ChatCompletionResponse>()?;
+    m.add_class::<response::ChatCompletionChunkResponse>()?;
+    m.add_class::<response::CompletionChoice>()?;
+    m.add_class::<response::CompletionResponse>()?;
     m.add_class::<mistralrs_core::TopLogprob>()?;
     m.add_class::<mistralrs_core::ModelDType>()?;
     m.add_class::<mistralrs_core::ImageGenerationResponseFormat>()?;
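At the Python boundary the new wrappers stringify the enum via to_string(), so the Python API keeps returning plain strings while the Rust core carries the typed reason. The conversion pattern, condensed with simplified stand-in types (not the actual pyo3 wrappers):

use mistralrs_core::StopReason;

// Simplified stand-ins for the core and pyo3-facing choice types.
struct CoreChoice {
    finish_reason: StopReason,
    index: usize,
}

struct PyChoice {
    finish_reason: String, // Python callers still see "stop", "canceled", ...
    index: usize,
}

impl From<CoreChoice> for PyChoice {
    fn from(c: CoreChoice) -> Self {
        PyChoice {
            // Display/to_string restores the string form at the boundary.
            finish_reason: c.finish_reason.to_string(),
            index: c.index,
        }
    }
}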
(Diffs for the remaining 3 changed files are not shown.)
