This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Add "context swap" functions to session and add "decoded_tokens" to snapshot read/write #424

Merged
merged 3 commits on Nov 12, 2023
210 changes: 188 additions & 22 deletions crates/llm-base/src/inference_session.rs
@@ -372,15 +372,22 @@ impl InferenceSession {
output_request: &mut OutputRequest,
mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
) -> Result<(), InferenceError> {
let beginning_of_sentence = self.n_past == 0;

let vocab = model.tokenizer();
let prompt_tokens = prompt.into().to_tokens(vocab, beginning_of_sentence)?;
let prompt_tokens = self.get_prompt_tokens(model, prompt)?;

if self.n_past + prompt_tokens.len() >= model.context_size() {
return Err(InferenceError::ContextFull);
}

self.feed_prompt_tokens(model, output_request, &mut callback, prompt_tokens)
}

fn feed_prompt_tokens<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
output_request: &mut OutputRequest,
mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
prompt_tokens: Vec<TokenId>,
) -> Result<(), InferenceError> {
'outer: for batch in prompt_tokens.chunks(self.config.n_batch) {
model.evaluate(self, batch, output_request);
for &tk in batch {
@@ -414,10 +421,46 @@ impl InferenceSession {
}
}
log::trace!("Finished feed prompt");

Ok(())
}

fn get_prompt_tokens<'a, P: Into<Prompt<'a>>>(
&self,
model: &dyn Model,
prompt: P,
) -> Result<Vec<TokenId>, TokenizationError> {
let beginning_of_sentence = self.n_past == 0;

let vocab = model.tokenizer();
prompt.into().to_tokens(vocab, beginning_of_sentence)
}

/// Feed a prompt to the model for this session.
/// Same as [Self::feed_prompt], but if the prompt would overflow the context window,
/// it first removes just enough older tokens (starting after the first `n_keep`) to make room.
#[instrument(skip_all)]
pub fn feed_prompt_with_swap<
'a,
E: std::error::Error + Send + Sync + 'static,
P: Into<Prompt<'a>>,
>(
&mut self,
model: &dyn Model,
prompt: P,
n_keep: usize,
output_request: &mut OutputRequest,
mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
) -> Result<(), InferenceError> {
let prompt_tokens = self.get_prompt_tokens(model, prompt)?;

if self.n_past + prompt_tokens.len() >= model.context_size() {
let rewind_by = self.n_past + prompt_tokens.len() - model.context_size();
self.remove_tokens(model, n_keep, rewind_by)
.map_err(|_e| InferenceError::ContextFull)?;
}

self.feed_prompt_tokens(model, output_request, &mut callback, prompt_tokens)
}
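
// --- Illustrative sketch, not part of this PR's diff ---
// The overflow arithmetic used by `feed_prompt_with_swap` above, on made-up numbers:
// the excess (n_past + prompt_len - context_size) is exactly how many tokens must be
// removed (starting after the first `n_keep` tokens) before the new prompt fits.
fn example_swap_overflow_arithmetic() {
    let context_size = 8usize; // stand-in for model.context_size()
    let n_past = 6usize; // tokens already evaluated in the session
    let prompt_len = 5usize; // tokens in the incoming prompt

    assert!(n_past + prompt_len >= context_size);
    let rewind_by = n_past + prompt_len - context_size;
    assert_eq!(rewind_by, 3);
    // After removing `rewind_by` tokens, the prompt fits exactly within the window.
    assert!(n_past - rewind_by + prompt_len <= context_size);
}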

/// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`.
pub fn rewind(&mut self, model: &dyn Model, num: usize) -> Result<Vec<TokenId>, RewindError> {
if !model.supports_rewind() {
@@ -445,6 +488,46 @@ impl InferenceSession {
Ok(deleted_tokens)
}

/// Removes `num` tokens from the buffer, starting at position `start_from`. Similar to [Self::rewind].
fn remove_tokens(
&mut self,
model: &dyn Model,
start_from: usize,
num: usize,
) -> Result<Vec<TokenId>, RewindError> {
if !model.supports_rewind() {
return Err(RewindError::UnsupportedArchitecture);
}

if start_from + num >= self.n_past {
return Err(RewindError::NotEnoughTokens);
}

// Remove the tokens from self.tokens.
let end = start_from + num;
let deleted_tokens: Vec<_> = self.tokens.drain(start_from..end).collect();

// Remove the corresponding bytes from decoded_tokens
let mut decoded_start = 0;
let mut decoded_end = 0;
if start_from != 0 {
for id in &self.tokens[0..start_from] {
decoded_start += model.tokenizer().token(*id as usize).len();
}
decoded_end += decoded_start;
}

for id in &deleted_tokens {
decoded_end += model.tokenizer().token(*id as usize).len();
}
self.decoded_tokens.drain(decoded_start..decoded_end);

// Decrement the n_past tokens counter.
self.n_past -= num;

Ok(deleted_tokens)
}
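
// --- Illustrative sketch, not part of this PR's diff ---
// Mirrors the byte-offset bookkeeping in `remove_tokens` above with a toy vocabulary:
// `decoded` holds the concatenated bytes of every token, so removing tokens
// `start_from..start_from + num` means draining the matching byte range.
fn example_decoded_byte_offsets() {
    // Hypothetical per-token byte strings (stand-in for model.tokenizer().token(id)).
    let vocab: [&[u8]; 5] = [b"Hel", b"lo", b", ", b"wor", b"ld"];
    let mut tokens: Vec<usize> = vec![0, 1, 2, 3, 4];
    let mut decoded: Vec<u8> = tokens.iter().flat_map(|&id| vocab[id].iter().copied()).collect();
    assert_eq!(decoded, b"Hello, world");

    let (start_from, num) = (1usize, 2usize); // drop "lo" and ", "
    let deleted: Vec<usize> = tokens.drain(start_from..start_from + num).collect();

    // Bytes contributed by the tokens that remain in front of the removed range ...
    let decoded_start: usize = tokens[..start_from].iter().map(|&id| vocab[id].len()).sum();
    // ... plus the bytes contributed by the removed tokens themselves.
    let decoded_end: usize = decoded_start + deleted.iter().map(|&id| vocab[id].len()).sum::<usize>();
    decoded.drain(decoded_start..decoded_end);
    assert_eq!(decoded, b"Helworld");
}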

/// Infer the next token for this session.
#[instrument(level = "trace", skip_all)]
pub fn infer_next_token(
@@ -510,19 +593,7 @@ impl InferenceSession {
) -> Result<InferenceStats, InferenceError> {
let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX);
if request.play_back_previous_tokens {
// "Play back" the existing tokens, so that loading from an inference snapshot works
// as expected.
let mut token_utf8_buf = TokenUtf8Buffer::new();
for token_id in &self.tokens {
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) =
token_utf8_buf.push(&model.tokenizer().token(*token_id as usize))
{
if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) {
return Err(InferenceError::UserCallback(Box::new(e)));
}
}
}
self.play_back_previous_tokens(model, &mut callback)?
}
log::trace!(
"Starting inference request with max_token_count: {}",
@@ -547,10 +618,25 @@ impl InferenceSession {
stats.feed_prompt_duration = start_at.elapsed().unwrap();
stats.prompt_tokens = self.n_past;

// After the prompt is consumed, sample tokens by repeatedly calling
// `infer_next_token`. We generate tokens until the model returns an
// EndOfText token, or we run out of space in the context window,
// or we reach the specified limit.
self.infer_tokens(model, rng, &mut callback, maximum_token_count, parameters)?;
stats.predict_duration = start_at.elapsed().unwrap();
stats.predict_tokens = self.n_past;

Ok(stats)
}

/// Sample tokens by repeatedly calling [Self::infer_next_token]. Generates
/// tokens until the model returns an EndOfText token, the context window
/// runs out of space, or the specified limit is reached.
fn infer_tokens<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
rng: &mut impl rand::Rng,
mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E>,
maximum_token_count: usize,
parameters: &InferenceParameters,
) -> Result<(), InferenceError> {
let mut tokens_processed = 0;
let mut token_utf8_buf = TokenUtf8Buffer::new();
while tokens_processed < maximum_token_count {
@@ -574,6 +660,79 @@

tokens_processed += 1;
}
Ok(())
}
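
// --- Illustrative sketch, not part of this PR's diff ---
// The termination conditions described in the doc comment above, demonstrated with a
// fake sampler: stop on an end-of-text marker, a full context window, or the caller's
// token limit, whichever comes first. All names and numbers here are made up.
fn example_generation_loop_termination() {
    let context_size = 16usize;
    let maximum_token_count = 8usize;
    let mut n_past = 10usize;
    let mut tokens_processed = 0usize;
    // Pretend token id 0 means "end of text" and it is emitted on the 5th step.
    let mut fake_sampler = (1..).map(|i| if i == 5 { 0u32 } else { i as u32 });

    while tokens_processed < maximum_token_count {
        if n_past >= context_size {
            break; // out of context space
        }
        let token = fake_sampler.next().unwrap();
        if token == 0 {
            break; // end of text
        }
        n_past += 1;
        tokens_processed += 1;
    }
    assert_eq!(tokens_processed, 4);
}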

/// "Play back" the existing tokens, so that loading from an inference snapshot works
/// as expected.
fn play_back_previous_tokens<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E>,
) -> Result<(), InferenceError> {
let mut token_utf8_buf = TokenUtf8Buffer::new();
for token_id in &self.tokens {
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) = token_utf8_buf.push(&model.tokenizer().token(*token_id as usize))
{
if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) {
return Err(InferenceError::UserCallback(Box::new(e)));
}
}
}
Ok(())
}
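
// --- Illustrative sketch, not part of this PR's diff ---
// Shows the idea behind the `TokenUtf8Buffer` used above: token byte strings may split a
// multi-byte UTF-8 character, so bytes are buffered until they decode cleanly before the
// callback sees them. This is a simplified stand-in, not the crate's actual implementation.
fn example_utf8_buffering() {
    // "é" is two bytes (0xC3, 0xA9); pretend the model emits them as separate tokens.
    let token_bytes: Vec<Vec<u8>> = vec![b"caf".to_vec(), vec![0xC3], vec![0xA9], b"!".to_vec()];

    let mut buf: Vec<u8> = Vec::new();
    let mut emitted = String::new();
    for bytes in token_bytes {
        buf.extend_from_slice(&bytes);
        // Flush only once the buffered bytes form valid UTF-8.
        if let Ok(s) = std::str::from_utf8(&buf) {
            emitted.push_str(s);
            buf.clear();
        }
    }
    assert_eq!(emitted, "café!");
}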

/// Generate text by using the provided [Model] to evaluate the `prompt`.
/// Works the same way as [Self::infer], except it supports infinite text generation via context swapping.
#[instrument(skip_all)]
pub fn infer_with_swap<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
rng: &mut impl rand::Rng,
request: &InferenceRequest,
n_keep: usize,
output_request: &mut OutputRequest,
mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E>,
) -> Result<InferenceStats, InferenceError> {
let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX);
if request.play_back_previous_tokens {
self.play_back_previous_tokens(model, &mut callback)?
}

// Infinite text generation via context swapping.
// If we run out of context:
// - keep the first n_keep tokens of the original prompt
// - remove half of the tokens after n_keep ((n_past - n_keep) / 2)
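// Worked example (illustrative numbers, not from this PR): with context_size = 2048,
// n_keep = 64 and n_past = 2048, this removes (2048 - 64) / 2 = 992 tokens starting at
// position 64, leaving 64 + 992 = 1056 tokens of context before generation continues.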
if self.n_past >= model.context_size() {
self.remove_tokens(model, n_keep, (self.n_past - n_keep) / 2)
.map_err(|_e| InferenceError::ContextFull)?;
}

log::trace!(
"Starting inference request with max_token_count: {}",
maximum_token_count
);

let mut stats = InferenceStats::default();
let start_at = std::time::SystemTime::now();

let parameters = request.parameters;

// Feed the initial prompt through the transformer, to update its
// context window with new data, if necessary.
if !request.prompt.is_empty() {
self.feed_prompt(
model,
request.prompt,
output_request,
feed_prompt_callback(&mut callback),
)?;
}
stats.feed_prompt_duration = start_at.elapsed().unwrap();
stats.prompt_tokens = self.n_past;

self.infer_tokens(model, rng, &mut callback, maximum_token_count, parameters)?;
stats.predict_duration = start_at.elapsed().unwrap();
stats.predict_tokens = self.n_past;

Expand Down Expand Up @@ -677,6 +836,7 @@ impl InferenceSession {
npast: self.n_past,
config: self.config,
tokens: self.tokens.clone(),
decoded_tokens: self.decoded_tokens.clone(),
last_logits: self.last_logits.clone(),
memory_k,
memory_v,
Expand Down Expand Up @@ -709,6 +869,7 @@ impl InferenceSession {

session.n_past = snapshot.npast;
session.tokens = snapshot.tokens;
session.decoded_tokens = snapshot.decoded_tokens;
session.last_logits = snapshot.last_logits;

Ok(session)
Expand Down Expand Up @@ -814,6 +975,8 @@ pub struct InferenceSnapshotRef<'a> {
pub config: InferenceSessionConfig,
/// All tokens generated by this inference session.
pub tokens: Vec<TokenId>,
/// All decoded tokens generated by this inference session.
pub decoded_tokens: Vec<u8>,
/// The vector of logits that was produced after the last inference.
pub last_logits: Vec<f32>,
/// The contents of the 'key' memory tensor.
@@ -832,6 +995,7 @@ impl InferenceSnapshotRef<'_> {
npast: self.npast,
config: self.config,
tokens: self.tokens.clone(),
decoded_tokens: self.decoded_tokens.clone(),
last_logits: self.last_logits.clone(),
memory_k: self.memory_k.to_vec(),
memory_v: self.memory_v.to_vec(),
@@ -850,6 +1014,8 @@ pub struct InferenceSnapshot {
pub config: InferenceSessionConfig,
/// All tokens generated by this inference session.
pub tokens: Vec<TokenId>,
/// All decoded tokens generated by this inference session.
pub decoded_tokens: Vec<u8>,
/// The vector of logits that was produced after the last inference.
pub last_logits: Vec<f32>,
/// The contents of the 'key' memory tensor.