Calculate and return logprobs
EricLBuehler committed Dec 18, 2023
1 parent (f399b80) · commit f62de2f
Showing 8 changed files with 237 additions and 151 deletions.
285 changes: 174 additions & 111 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -16,7 +16,7 @@ candle-lora-transformers = { git = "https://github.com/EricLBuehler/candle-lora.
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }
dyn-fmt = "0.4.0"
serde = { version = "1.0.190", features = ["serde_derive"] }
tokenizers = "0.13.4"
tokenizers = "0.15.0"
uuid = { version = "1.5.0", features = ["v4"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }
hf-hub = "0.3.2"
@@ -28,7 +28,7 @@ cudarc = { version = "0.9.14", features = ["f16"], optional = true }
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"], optional = true }
candle-flash-attn = { git = "https://github.com/huggingface/candle.git", version = "0.3.0", optional = true }
clap = { version = "4.4.7", features = ["derive"] }
-candle-sampling = { git = "https://github.com/EricLBuehler/candle-sampling.git", version = "0.1.0" }
+candle-sampling = { git = "https://github.com/EricLBuehler/candle-sampling.git", version = "0.2.0" }
futures = "0.3.29"
tokio = { version = "1.33.0", features = ["sync"] }
env_logger = "0.10.1"
13 changes: 8 additions & 5 deletions src/openai/pipelines/llama.rs
@@ -202,7 +202,11 @@ impl<'s> ModulePipeline<'s> for LlamaPipeline {
) -> Result<Vec<TokenOrFinishReason>, APIError> {
let eos_token_id = self.tokenizer.token_to_id(EOS_TOKEN);

-let mut logits_processor = sampling_params.get_logits_processor(SAMPLING_SEED);
+let mut logits_processor = sampling_params.get_logits_processor(
+SAMPLING_SEED,
+&self.tokenizer,
+sampling_params.logprobs.unwrap_or(1),
+);
let stop_tokens = match sampling_params.stop.clone() {
Some(stop) => match stop {
StopTokens::Multi(multi) => multi,
@@ -241,24 +245,23 @@ impl<'s> ModulePipeline<'s> for LlamaPipeline {
};

let next_token = logits_processor.sample(&logits).map_err(APIError::from)?;
-if let Some(text) = self.tokenizer.id_to_token(next_token) {
+if let Some(text) = self.tokenizer.id_to_token(next_token.token as u32) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
if stop_tokens.contains(&text) {
result.push(Right("stop".to_string()));
continue;
}
}

-if Some(next_token) == eos_token_id {
+if Some(next_token.token) == eos_token_id.map(|x| x as usize) {
result.push(Right("stop".to_string()));
continue;
}
if tokens_generated >= sampling_params.max_tokens {
result.push(Right("length".to_string()));
continue;
}
-// TODO(EricLBuehler): Actually compute logprobs
-result.push(Left((next_token as usize, 0.)));
+result.push(Left(next_token));
}

Ok(result)
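
For context, the sampler above now hands back a richer Logprobs value rather than a bare token id, which is why the pipeline reads next_token.token and pushes the whole value into the result. The sketch below is a minimal, self-contained illustration of that flow; the Logprobs struct here is a stand-in modeling only the two fields this commit actually uses (token and logprob), not the full candle-sampling definition.

use either::Either;

// Stand-in for candle_sampling::logits_processor::Logprobs; only the
// fields exercised in this commit are modeled here.
#[derive(Clone, Debug)]
struct Logprobs {
    token: usize,
    logprob: f32,
}

// Decide whether a sampled token ends the sequence or is emitted along
// with its log probability, mirroring the branch structure above.
fn classify(next_token: Logprobs, eos_token_id: Option<u32>) -> Either<Logprobs, String> {
    if Some(next_token.token) == eos_token_id.map(|x| x as usize) {
        return Either::Right("stop".to_string());
    }
    Either::Left(next_token)
}
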
19 changes: 11 additions & 8 deletions src/openai/pipelines/llm_engine.rs
@@ -9,7 +9,9 @@ use tokenizers::Encoding;

use crate::{
openai::{
-responses::{APIError, ChatChoice, ChatChoiceData, ChatCompletionUsageResponse},
+responses::{
+APIError, ChatChoice, ChatChoiceData, ChatCompletionUsageResponse, WrapperLogprobs,
+},
sampling_params::SamplingParams,
utils::get_created_time_secs,
},
@@ -157,8 +159,8 @@ impl<'a> LLMEngine<'a> {

for (result, (_, seq)) in zip(result, seqs) {
match result {
-Either::Left((token, logprob)) => {
-seq.deref_mut().add_token(token, logprob);
+Either::Left(logprobs) => {
+seq.deref_mut().add_token(logprobs);
}
Either::Right(finish_reason) => {
seq.deref_mut().set_finish_reason(finish_reason)
@@ -183,11 +185,10 @@ impl<'a> LLMEngine<'a> {

let mut choices = Vec::new();
for (index, seq) in top_n.iter().enumerate() {
-let data = seq
-.deref_mut()
-.get_token_ids()
+let outputs = seq.deref_mut().get_output_tokens();
+let data = outputs
.iter()
-.map(|x| (*x).try_into().unwrap())
+.map(|x| x.token.try_into().unwrap())
.collect::<Vec<_>>();
let data = self.pipeline.tokenizer().detokenize(&data)?;
let choice = ChatChoice {
@@ -197,7 +198,9 @@
},
finish_reason: Some(seq.deref_mut().get_finish_reason().clone()),
index,
-logprobs: None, // TODO(EricLBuehler): actually add this
+logprobs: Some(WrapperLogprobs {
+content: seq.deref_mut().get_output_tokens(),
+}),
};
choices.push(choice);
}
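
At response-build time the engine now pulls the full list of Logprobs from each sequence: the token ids are detokenized into the choice text, and the same list is wrapped in WrapperLogprobs for the logprobs field. Below is a rough, self-contained sketch of that split; Logprobs is again reduced to a stand-in with only the fields used in this commit.

// Stand-in types; the real Logprobs comes from candle-sampling and the
// real WrapperLogprobs from src/openai/responses.rs.
#[derive(Clone, Debug)]
struct Logprobs {
    token: usize,
    logprob: f32,
}

struct WrapperLogprobs {
    content: Vec<Logprobs>,
}

// Split a sequence's outputs into (a) the token ids to detokenize into
// the choice text and (b) the per-token logprobs returned to the client.
fn split_outputs(outputs: Vec<Logprobs>) -> (Vec<u32>, WrapperLogprobs) {
    let ids: Vec<u32> = outputs.iter().map(|lp| lp.token as u32).collect();
    (ids, WrapperLogprobs { content: outputs })
}
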
3 changes: 2 additions & 1 deletion src/openai/pipelines/mod.rs
@@ -1,6 +1,7 @@
use std::{env, path::PathBuf, sync::Arc};

use candle_core::{DType, Device, Tensor, WithDType};
+use candle_sampling::logits_processor::Logprobs;
use either::Either;

use crate::{paged_attention::input_metadata::InputMetadata, scheduler::sequence::Sequence};
@@ -15,7 +16,7 @@ pub mod llama;
/// which are used to scheduler and manage the cache during generation requests, respectively.
pub mod llm_engine;

-type TokenOrFinishReason = Either<(usize, f32), String>;
+type TokenOrFinishReason = Either<Logprobs, String>;

pub trait ModulePipeline<'s>: Send + Sync {
fn forward(
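
The TokenOrFinishReason alias is the seam between the pipelines and the engine: Left now carries a full Logprobs value instead of a (usize, f32) pair. A small illustrative consumer, using the same stand-in struct as above, might look like this.

use either::Either;

// Stand-in for the candle-sampling Logprobs type.
#[derive(Clone, Debug)]
struct Logprobs {
    token: usize,
    logprob: f32,
}

type TokenOrFinishReason = Either<Logprobs, String>;

// Render one generation step for logging or debugging.
fn describe(step: &TokenOrFinishReason) -> String {
    match step {
        Either::Left(lp) => format!("token {} (logprob {:.3})", lp.token, lp.logprob),
        Either::Right(reason) => format!("finished: {}", reason),
    }
}
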
17 changes: 4 additions & 13 deletions src/openai/responses.rs
@@ -1,4 +1,5 @@
use actix_web::error;
+use candle_sampling::logits_processor::Logprobs;
use derive_more::{Display, Error};

use serde::{Deserialize, Serialize};
@@ -43,26 +44,16 @@ pub struct ChatChoiceData {
}

#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct TopLogprob {
-token: usize,
-logprob: f32,
-bytes: String,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Logprobs {
-token: usize,
-logprob: f32,
-bytes: String,
-top_logprobs: Vec<TopLogprob>,
+pub struct WrapperLogprobs {
+pub content: Vec<Logprobs>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
pub message: ChatChoiceData,
pub finish_reason: Option<String>,
pub index: usize,
-pub logprobs: Option<Logprobs>,
+pub logprobs: Option<WrapperLogprobs>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
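
With the local Logprobs and TopLogprob structs removed, the serialized shape of a choice's logprobs field is now whatever candle-sampling's Logprobs serializes to, wrapped in a content array. As an illustration only (the real type may also carry byte and top-logprob information), a stand-in with just the two fields used in this commit would serialize as follows, assuming serde_json is available alongside the serde dependency already in Cargo.toml.

use serde::Serialize;

// Illustrative stand-in; not the exact candle-sampling definition.
#[derive(Serialize)]
struct Logprobs {
    token: usize,
    logprob: f32,
}

#[derive(Serialize)]
struct WrapperLogprobs {
    content: Vec<Logprobs>,
}

fn main() {
    let wrapped = WrapperLogprobs {
        content: vec![Logprobs { token: 9906, logprob: -0.12 }],
    };
    // Yields JSON along the lines of: {"content":[{"token":9906,"logprob":-0.12}]}
    println!("{}", serde_json::to_string(&wrapped).unwrap());
}
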
17 changes: 15 additions & 2 deletions src/openai/sampling_params.rs
@@ -1,6 +1,7 @@
use std::ops::Range;

use candle_sampling::logits_processor::{LogitsProcessor, SamplingMethod};
+use tokenizers::Tokenizer;

use super::{requests::StopTokens, responses::APIError};

@@ -64,10 +65,11 @@ pub struct SamplingParams {
/// Whether to ignore EOS token.
pub ignore_eos: bool,
/// Max number of toks to gen per output seq.
-///rec. default = 16
+/// rec. default = 16
pub max_tokens: usize,
/// Num of log probs to return per output token. Follows OpenAI API, return result include the log probabilities on the `logprobs` most likely tokens.
/// will always return the log prob of the sampled token, so there may be up to `logprobs+1` elements in the response.
+/// Default = 1
pub logprobs: Option<usize>,
/// Num of log probs to return per prompt token.
pub prompt_logprobs: Option<usize>,
@@ -132,27 +134,38 @@ impl SamplingParams {
Ok(this)
}

-pub fn get_logits_processor(&self, seed: u64) -> LogitsProcessor {
+pub fn get_logits_processor<'a>(
+&self,
+seed: u64,
+tokenizer: &'a Tokenizer,
+top_n_logprobs: usize,
+) -> LogitsProcessor<'a> {
if self.top_k == -1 && self.top_p == 1. {
// Greedy
LogitsProcessor::new(
seed,
Some(self.temperature.into()),
SamplingMethod::Multinomial,
+top_n_logprobs,
+tokenizer,
)
} else if self.top_k > 0 && self.top_p == 1. {
// Top-k
LogitsProcessor::new(
seed,
Some(self.temperature.into()),
SamplingMethod::TopK(self.top_k.try_into().unwrap()),
+top_n_logprobs,
+tokenizer,
)
} else if self.top_k == -1 && self.top_p != 1. {
// Top-p
LogitsProcessor::new(
seed,
Some(self.temperature.into()),
SamplingMethod::TopP(self.top_p.into()),
+top_n_logprobs,
+tokenizer,
)
} else {
unreachable!()
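
The branch selection above is unchanged in spirit: the sampling method is derived from top_k and top_p, and each LogitsProcessor::new call now also receives the tokenizer plus the number of top logprobs to record. Below is a self-contained sketch of just the selection logic, with a stand-in enum instead of candle-sampling's SamplingMethod; the parameter types are assumptions, not the crate's exact signature.

// Stand-in for candle_sampling::logits_processor::SamplingMethod.
#[derive(Debug, PartialEq)]
enum Method {
    Multinomial,
    TopK(usize),
    TopP(f64),
}

// Mirror of the top_k/top_p branching in get_logits_processor above.
// Returns None for combinations the original code treats as unreachable.
fn choose_method(top_k: i64, top_p: f64) -> Option<Method> {
    if top_k == -1 && top_p == 1.0 {
        // Sample from the full distribution.
        Some(Method::Multinomial)
    } else if top_k > 0 && top_p == 1.0 {
        // Restrict sampling to the k most likely tokens.
        Some(Method::TopK(top_k as usize))
    } else if top_k == -1 && top_p != 1.0 {
        // Nucleus (top-p) sampling.
        Some(Method::TopP(top_p))
    } else {
        None
    }
}

The None case corresponds to the unreachable!() arm above, i.e. combining top-k with top-p is not handled by this commit.
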
30 changes: 21 additions & 9 deletions src/scheduler/sequence.rs
@@ -3,6 +3,8 @@ use std::{
sync::{Arc, Mutex, MutexGuard},
};

+use candle_sampling::logits_processor::Logprobs;
+
use super::block_engine::LogicalTokenBlock;

#[derive(Clone)]
@@ -17,7 +19,7 @@ pub enum SequenceStatus {

pub struct SequenceData {
prompt_token_ids: Vec<usize>,
-output_token_ids: Vec<(usize, f32)>,
+output_token_ids: Vec<Logprobs>,
cumulative_logprob: f32,
status: SequenceStatus,
}
@@ -32,9 +34,9 @@ impl SequenceData {
}
}

-pub fn append_token_id(&mut self, token_id: usize, logprob: f32) {
-self.output_token_ids.push((token_id, logprob));
-self.cumulative_logprob += logprob;
+pub fn append_token_id(&mut self, logprobs: Logprobs) {
+self.cumulative_logprob += logprobs.logprob;
+self.output_token_ids.push(logprobs);
}

pub fn set_status(&mut self, status: SequenceStatus) {
@@ -67,9 +69,9 @@ impl _Sequence {
this
}

-pub fn add_token(&mut self, token: usize, logprob: f32) {
-self.deref_mut().append_token_id(token, logprob);
-self.append_token_to_blocks(token);
+pub fn add_token(&mut self, logprobs: Logprobs) {
+self.append_token_to_blocks(logprobs.token);
+self.deref_mut().append_token_id(logprobs);
}

pub fn blocks_to_add_new_tok(&mut self) -> usize {
@@ -104,15 +106,21 @@

pub fn get_token_ids(&self) -> Vec<usize> {
let mut res = self.deref().prompt_token_ids.clone();
-res.extend(self.deref().output_token_ids.iter().map(|(x, _)| x).clone());
+res.extend(
+self.deref()
+.output_token_ids
+.iter()
+.map(|logprobs| logprobs.token)
+.clone(),
+);
res
}

pub fn get_last_token_id(&self) -> usize {
if self.deref().output_token_ids.is_empty() {
*self.deref().prompt_token_ids.last().unwrap()
} else {
-self.deref().output_token_ids.last().unwrap().0
+self.deref().output_token_ids.last().unwrap().token
}
}

@@ -145,6 +153,10 @@ impl _Sequence {
}
}

+pub fn get_output_tokens(&self) -> Vec<Logprobs> {
+self.deref().output_token_ids.clone() // TODO(EricLBuehler): Better way to do this?
+}
+
fn append_tokens_to_blocks(&mut self, tokens: Vec<usize>) {
for tok in tokens {
self.append_token_to_blocks(tok);
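
Sequence bookkeeping now stores a full Logprobs value per generated token and keeps a running cumulative_logprob. The stand-in below condenses the accessors touched by this commit into one self-contained struct; the field types are assumptions matching the usage above, not the exact project definitions.

// Stand-in for the candle-sampling Logprobs type.
#[derive(Clone, Debug)]
struct Logprobs {
    token: usize,
    logprob: f32,
}

// Condensed stand-in for the SequenceData / _Sequence bookkeeping above.
struct SequenceData {
    prompt_token_ids: Vec<usize>,
    output_token_ids: Vec<Logprobs>,
    cumulative_logprob: f32,
}

impl SequenceData {
    fn append_token_id(&mut self, logprobs: Logprobs) {
        // Keep a running total of the completion's log probability.
        self.cumulative_logprob += logprobs.logprob;
        self.output_token_ids.push(logprobs);
    }

    fn get_token_ids(&self) -> Vec<usize> {
        // Prompt ids followed by every generated token id.
        let mut res = self.prompt_token_ids.clone();
        res.extend(self.output_token_ids.iter().map(|lp| lp.token));
        res
    }

    fn get_output_tokens(&self) -> Vec<Logprobs> {
        // Cloned so the response layer can own the per-token logprobs.
        self.output_token_ids.clone()
    }
}
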
