diff --git a/binaries/llm-test/configs/bloom.json b/binaries/llm-test/configs/bloom.json
index b8c4bc2a..5383386d 100644
--- a/binaries/llm-test/configs/bloom.json
+++ b/binaries/llm-test/configs/bloom.json
@@ -15,6 +15,9 @@
         "input": "Rustformers is",
         "output": 15
       }
+    },
+    {
+      "Delete": {}
     }
   ]
 }
diff --git a/binaries/llm-test/configs/gptj.json b/binaries/llm-test/configs/gptj.json
index 85493b5c..50966748 100644
--- a/binaries/llm-test/configs/gptj.json
+++ b/binaries/llm-test/configs/gptj.json
@@ -15,6 +15,9 @@
         "input": "Rustformers is",
         "output": 257
       }
+    },
+    {
+      "Delete": {}
     }
   ]
 }
diff --git a/binaries/llm-test/configs/gptneox.json b/binaries/llm-test/configs/gptneox.json
index 3e6b84cb..c8cce4d9 100644
--- a/binaries/llm-test/configs/gptneox.json
+++ b/binaries/llm-test/configs/gptneox.json
@@ -15,6 +15,9 @@
         "input": "Rustformers is",
         "output": 247
       }
+    },
+    {
+      "Delete": {}
     }
   ]
 }
diff --git a/binaries/llm-test/configs/llama.json b/binaries/llm-test/configs/llama.json
index 1e2f23f4..9bd6094a 100644
--- a/binaries/llm-test/configs/llama.json
+++ b/binaries/llm-test/configs/llama.json
@@ -15,6 +15,9 @@
         "input": "Rustformers is",
         "output": 260
       }
+    },
+    {
+      "Delete": {}
     }
   ]
 }
diff --git a/binaries/llm-test/configs/mpt.json b/binaries/llm-test/configs/mpt.json
index 7142c143..57a8bc89 100644
--- a/binaries/llm-test/configs/mpt.json
+++ b/binaries/llm-test/configs/mpt.json
@@ -15,6 +15,9 @@
         "input": "Rustformers is",
         "output": 247
       }
+    },
+    {
+      "Delete": {}
     }
   ]
 }
diff --git a/binaries/llm-test/src/delete.rs b/binaries/llm-test/src/delete.rs
new file mode 100644
index 00000000..2c609cb5
--- /dev/null
+++ b/binaries/llm-test/src/delete.rs
@@ -0,0 +1,100 @@
+//! Tests the model's token manipulation APIs:
+//!
+//! * [llm::InferenceSession::delete_tokens()]
+//!
+//! See [crate::TestCase::Delete].
+
+use std::convert::Infallible;
+
+use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest};
+use serde::Serialize;
+
+use crate::{TestCaseReport, TestCaseReportMeta};
+
+/// Error tolerance for the float comparisons.
+const TOLERANCE: f32 = 1e-7;
+
+/// Tests that models can delete tokens without changing the model's behavior.
+pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport {
+    let report = DeleteReport::default();
+    let mut session = model.start_session(Default::default());
+    let mut output = OutputRequest {
+        all_logits: Some(vec![]),
+        ..Default::default()
+    };
+
+    // Feed some tokens
+    if let Err(err) = feed_prompt("The llama lived on the", &mut session, model, &mut output) {
+        return report.failure(&err.to_string());
+    }
+
+    // Add token and get the logits
+    if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) {
+        return report.failure(&err.to_string());
+    }
+    let Some(original_logits) = output.all_logits.clone() else {
+        return report.failure("Model did not return logits.");
+    };
+
+    // Delete, then re-add. Verify logits are the same.
+    if let Err(err) = session.delete_tokens(model, 1) {
+        return report.failure(&err.to_string());
+    }
+    if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) {
+        return report.failure(&err.to_string());
+    }
+    let Some(redone_logits) = output.all_logits.clone() else {
+        return report.failure("Second run of model did not return logits.");
+    };
+
+    // Compare the logits
+    for (idx, (&original, redone)) in original_logits.iter().zip(redone_logits).enumerate() {
+        if original > redone + TOLERANCE || original < redone - TOLERANCE {
+            return report.failure(&format!(
+                "Expected logits to be the same after delete, but differed at {idx}, \
+                expected {original}, but was {redone}."
+            ));
+        }
+    }
+
+    log::info!("`can_delete` test passed (no expected output)!");
+    report.success()
+}
+
+fn feed_prompt(
+    prompt: &str,
+    session: &mut InferenceSession,
+    model: &impl Model,
+    output: &mut OutputRequest,
+) -> Result<(), llm::InferenceError> {
+    session.feed_prompt(model, &Default::default(), prompt, output, |x| {
+        always_continue(x)
+    })
+}
+
+fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
+    Ok(InferenceFeedback::Continue)
+}
+
+#[derive(Serialize, Default)]
+pub struct DeleteReport {
+    output: usize,
+}
+
+impl DeleteReport {
+    fn failure(self, msg: &str) -> TestCaseReport {
+        TestCaseReport {
+            meta: TestCaseReportMeta::Error {
+                error: msg.to_owned(),
+            },
+            report: crate::TestCaseReportInner::Delete(self),
+        }
+    }
+
+    fn success(self) -> TestCaseReport {
+        TestCaseReport {
+            meta: TestCaseReportMeta::Success,
+            report: crate::TestCaseReportInner::Delete(self),
+        }
+    }
+}
diff --git a/binaries/llm-test/src/main.rs b/binaries/llm-test/src/main.rs
index b14ed78a..b1bc9b07 100644
--- a/binaries/llm-test/src/main.rs
+++ b/binaries/llm-test/src/main.rs
@@ -1,6 +1,7 @@
 //! Test runner for all LLMs.
 
 mod common;
+mod delete;
 mod inference;
 mod tokens;
 
@@ -128,6 +129,7 @@ enum TestCase {
         input: String,
         output: usize,
     },
+    Delete {},
 }
 
 #[derive(Serialize)]
@@ -158,6 +160,7 @@ pub enum TestCaseReportInner {
         inference_stats: Option<InferenceStats>,
     },
     Tokens(tokens::TokensReport),
+    Delete(delete::DeleteReport),
 }
 
 async fn test_model(
@@ -278,6 +281,9 @@ async fn test_model(
             TestCase::Tokens { input, output } => {
                 test_case_reports.push(tokens::can_feed(&model, input, *output));
             }
+            TestCase::Delete {} => {
+                test_case_reports.push(delete::can_delete(&model));
+            }
         }
     }
     let first_error: Option<String> =
diff --git a/binaries/llm-test/src/tokens.rs b/binaries/llm-test/src/tokens.rs
index 52f37019..a2a490b6 100644
--- a/binaries/llm-test/src/tokens.rs
+++ b/binaries/llm-test/src/tokens.rs
@@ -6,7 +6,7 @@
 
 use std::convert::Infallible;
 
-use llm::{InferenceFeedback, Model, OutputRequest};
+use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest};
 use serde::Serialize;
 
 use crate::{TestCaseReport, TestCaseReportMeta};
@@ -14,20 +14,14 @@ use crate::{TestCaseReport, TestCaseReportMeta};
 /// Tests that the model performs as expected when feeding tokens
 pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize) -> TestCaseReport {
     let mut report = TokensReport::default();
     let mut session = model.start_session(Default::default());
     let mut output = OutputRequest {
         all_logits: Some(vec![]),
         ..Default::default()
     };
 
-    let feed_prompt = &mut |prompt: &str| {
-        session.feed_prompt(model, &Default::default(), prompt, &mut output, |x| {
-            always_continue(x)
-        })
-    };
-
-    if let Err(err) = feed_prompt(input) {
+    if let Err(err) = feed_prompt(input, &mut session, model, &mut output) {
         return report.failure(&err.to_string());
     };
@@ -62,9 +56,21 @@ pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize)
         ));
     }
 
+    log::info!("`can_feed` test passed (no expected output)!");
     report.success()
 }
 
+fn feed_prompt(
+    prompt: &str,
+    session: &mut InferenceSession,
+    model: &impl Model,
+    output: &mut OutputRequest,
+) -> Result<(), llm::InferenceError> {
+    session.feed_prompt(model, &Default::default(), prompt, output, |x| {
+        always_continue(x)
+    })
+}
+
 fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
     Ok(InferenceFeedback::Continue)
 }
diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index 37861174..27428d73 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -333,6 +333,37 @@ impl InferenceSession {
         Ok(())
     }
 
+    /// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`.
+    pub fn delete_tokens(
+        &mut self,
+        model: &dyn Model,
+        num: usize,
+    ) -> Result<Vec<TokenId>, DeleteError> {
+        if !model.supports_delete() {
+            return Err(DeleteError::UnsupportedArchitecture);
+        }
+
+        if num >= self.n_past {
+            return Err(DeleteError::NotEnoughTokens);
+        }
+
+        // Remove the tokens from self.tokens.
+        let token_start = self.n_past - num;
+        let deleted_tokens: Vec<_> = self.tokens.drain(token_start..).collect();
+
+        // Remove the corresponding chars from decoded
+        let mut decoded_start = self.decoded_tokens.len();
+        for id in &deleted_tokens {
+            decoded_start -= model.tokenizer().token(*id as usize).len();
+        }
+        self.decoded_tokens.truncate(decoded_start);
+
+        // Decrement the n_past tokens counter.
+        self.n_past -= num;
+
+        Ok(deleted_tokens)
+    }
+
     /// Infer the next token for this session.
     pub fn infer_next_token(
         &mut self,
@@ -637,6 +668,18 @@ pub enum InferenceError {
     UserCallback(Box<dyn std::error::Error + Send + Sync>),
 }
 
+#[derive(Error, Debug)]
+/// Errors encountered when deleting tokens from an inference session.
+pub enum DeleteError {
+    /// Tried deleting more tokens than were available
+    #[error("tried deleting more tokens than were available")]
+    NotEnoughTokens,
+
+    /// Model architecture does not support delete
+    #[error("model architecture does not support deletes")]
+    UnsupportedArchitecture,
+}
+
 #[derive(Error, Debug)]
 /// Errors encountered during the snapshot process.
 pub enum SnapshotError {
diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs
index d40a9077..479aba2d 100644
--- a/crates/llm-base/src/lib.rs
+++ b/crates/llm-base/src/lib.rs
@@ -23,9 +23,9 @@
 pub use ggml;
 pub use ggml::Type as ElementType;
 pub use inference_session::{
-    feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest,
-    InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
-    InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, SnapshotError,
+    feed_prompt_callback, DeleteError, GraphOutputs, InferenceError, InferenceFeedback,
+    InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig,
+    InferenceSnapshot, InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, SnapshotError,
 };
 pub use loader::{
     load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic,
diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs
index 2de02828..90c86d58 100644
--- a/crates/llm-base/src/model/mod.rs
+++ b/crates/llm-base/src/model/mod.rs
@@ -86,6 +86,12 @@ pub trait KnownModel: Send + Sync {
 
     /// Get the list of regexes to use to determine if a tensor in this model should not be quantized.
     fn skip_quantize_tensors() -> Vec<Regex>;
+
+    /// Returns whether the model supports deleting tokens.
+    fn supports_delete(&self) -> bool {
+        // Assume we can't delete unless otherwise specified
+        false
+    }
 }
 
 /// A type-erased model to allow for interacting with a model without knowing
@@ -118,6 +124,9 @@ pub trait Model: Send + Sync {
     /// Get the end of text/end of string token ID. This value is defined by model implementers.
     fn eot_token_id(&self) -> TokenId;
+
+    /// Returns whether the model supports deleting tokens.
+    fn supports_delete(&self) -> bool;
 }
 
 impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
     fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
@@ -149,6 +158,10 @@ impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
     fn eot_token_id(&self) -> TokenId {
         KnownModel::eot_token_id(self)
     }
+
+    fn supports_delete(&self) -> bool {
+        KnownModel::supports_delete(self)
+    }
 }
 
 /// Implemented by model hyperparameters for interacting with hyperparameters
diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index 2be90739..2cc22609 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -78,13 +78,13 @@ use std::{
 
 // This is the "user-facing" API, and GGML may not always be our backend.
 pub use llm_base::{
     feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout,
-    quantize, samplers, ElementType, FileType, FileTypeFormat, FormatMagic, Hyperparameters,
-    InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse,
-    InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef,
-    InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model,
-    ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress,
-    Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer,
-    TokenizerSource,
+    quantize, samplers, DeleteError, ElementType, FileType, FileTypeFormat, FormatMagic,
+    Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest,
+    InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
+    InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress,
+    Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError,
+    QuantizeProgress, Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer,
+    TokenizationError, Tokenizer, TokenizerSource,
 };
 use serde::Serialize;
diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs
index 4e9aa192..5ef63b79 100644
--- a/crates/models/bloom/src/lib.rs
+++ b/crates/models/bloom/src/lib.rs
@@ -396,6 +396,10 @@ impl KnownModel for Bloom {
     fn skip_quantize_tensors() -> Vec<Regex> {
         vec![]
     }
+
+    fn supports_delete(&self) -> bool {
+        true
+    }
 }
 
 /// BLOOM [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs
index 195f876a..3d6cbcd2 100644
--- a/crates/models/gptj/src/lib.rs
+++ b/crates/models/gptj/src/lib.rs
@@ -318,6 +318,10 @@ impl KnownModel for GptJ {
     fn skip_quantize_tensors() -> Vec<Regex> {
         vec![]
     }
+
+    fn supports_delete(&self) -> bool {
+        true
+    }
 }
 
 /// GPT-J [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs
index 5339b901..641cfbcb 100644
--- a/crates/models/gptneox/src/lib.rs
+++ b/crates/models/gptneox/src/lib.rs
@@ -364,6 +364,10 @@ impl KnownModel for GptNeoX {
     fn skip_quantize_tensors() -> Vec<Regex> {
         vec![]
     }
+
+    fn supports_delete(&self) -> bool {
+        true
+    }
 }
 
 /// GPT-NeoX [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs
index 6e7d4b11..0ea6661e 100644
--- a/crates/models/llama/src/lib.rs
+++ b/crates/models/llama/src/lib.rs
@@ -348,6 +348,10 @@ impl KnownModel for Llama {
     fn skip_quantize_tensors() -> Vec<Regex> {
         vec![]
     }
+
+    fn supports_delete(&self) -> bool {
+        true
+    }
 }
 
 /// LLaMA [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs
index 18991adf..757a372f 100644
--- a/crates/models/mpt/src/lib.rs
+++ b/crates/models/mpt/src/lib.rs
@@ -298,6 +298,10 @@ impl KnownModel for Mpt {
     fn skip_quantize_tensors() -> Vec<Regex> {
         vec![]
     }
+
+    fn supports_delete(&self) -> bool {
+        true
+    }
 }
 
 /// MPT [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
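
A minimal usage sketch of the API added above, for review context only and not part of the patch: it assumes a model has already been loaded (for example via `llm::load`), and the prompt text, the number of tokens removed, and the `demo_delete` helper name are illustrative. It feeds a prompt, rolls back the last two tokens with the new `InferenceSession::delete_tokens`, and handles both `DeleteError` variants.

    use std::convert::Infallible;

    use llm::{DeleteError, InferenceFeedback, Model, OutputRequest};

    // Illustrative helper: `model` is any loaded `llm::Model`.
    fn demo_delete(model: &dyn Model) -> Result<(), llm::InferenceError> {
        let mut session = model.start_session(Default::default());
        let mut output = OutputRequest::default();

        // Feed a prompt exactly as the tests in this diff do.
        session.feed_prompt(
            model,
            &Default::default(),
            "The llama lived on the",
            &mut output,
            |_| -> Result<InferenceFeedback, Infallible> { Ok(InferenceFeedback::Continue) },
        )?;

        // Roll back the two most recent tokens; Ok carries the deleted token IDs.
        match session.delete_tokens(model, 2) {
            Ok(deleted) => println!("deleted {} tokens", deleted.len()),
            Err(DeleteError::UnsupportedArchitecture) => {
                println!("this architecture does not opt in to `supports_delete`")
            }
            Err(DeleteError::NotEnoughTokens) => println!("asked to delete too many tokens"),
        }
        Ok(())
    }

Note that `delete_tokens` only rewinds the session bookkeeping (`tokens`, `decoded_tokens`, `n_past`); architectures that do not override the `supports_delete` default reject the call with `UnsupportedArchitecture`.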