From fe5c112e5fb75dcdf9b289b18f85659577949032 Mon Sep 17 00:00:00 2001
From: Steven Weiss
Date: Wed, 5 Jul 2023 19:06:39 -0700
Subject: [PATCH] Add delete tokens test and impl.

Note that llama.json fails the test, so the LLaMA architecture likely does
not support token deletion. I may investigate further, though.
---
 binaries/llm-test/configs/bloom.json     |   3 +
 binaries/llm-test/configs/gptj.json      |   3 +
 binaries/llm-test/configs/gptneox.json   |   3 +
 binaries/llm-test/configs/mpt.json       |   3 +
 binaries/llm-test/src/delete.rs          | 100 +++++++++++++++++++++++
 binaries/llm-test/src/main.rs            |   6 ++
 binaries/llm-test/src/tokens.rs          |  23 ++++--
 crates/llm-base/src/inference_session.rs |  41 ++++++++++
 crates/llm-base/src/model/mod.rs         |  13 +++
 crates/models/bloom/src/lib.rs           |   4 +
 crates/models/gptj/src/lib.rs            |   4 +
 crates/models/gptneox/src/lib.rs         |   4 +
 crates/models/mpt/src/lib.rs             |   4 +
 13 files changed, 202 insertions(+), 9 deletions(-)
 create mode 100644 binaries/llm-test/src/delete.rs

diff --git a/binaries/llm-test/configs/bloom.json b/binaries/llm-test/configs/bloom.json
index b8c4bc2a..5383386d 100644
--- a/binaries/llm-test/configs/bloom.json
+++ b/binaries/llm-test/configs/bloom.json
@@ -15,6 +15,9 @@
         "input": "Rustformers is",
         "output": 15
       }
+    },
+    {
+      "Delete": {}
     }
   ]
 }
diff --git a/binaries/llm-test/configs/gptj.json b/binaries/llm-test/configs/gptj.json
index 85493b5c..50966748 100644
--- a/binaries/llm-test/configs/gptj.json
+++ b/binaries/llm-test/configs/gptj.json
@@ -15,6 +15,9 @@
         "input": "Rustformers is",
         "output": 257
       }
+    },
+    {
+      "Delete": {}
     }
   ]
 }
diff --git a/binaries/llm-test/configs/gptneox.json b/binaries/llm-test/configs/gptneox.json
index 3e6b84cb..c8cce4d9 100644
--- a/binaries/llm-test/configs/gptneox.json
+++ b/binaries/llm-test/configs/gptneox.json
@@ -15,6 +15,9 @@
         "input": "Rustformers is",
         "output": 247
       }
+    },
+    {
+      "Delete": {}
     }
   ]
 }
diff --git a/binaries/llm-test/configs/mpt.json b/binaries/llm-test/configs/mpt.json
index 7142c143..57a8bc89 100644
--- a/binaries/llm-test/configs/mpt.json
+++ b/binaries/llm-test/configs/mpt.json
@@ -15,6 +15,9 @@
         "input": "Rustformers is",
         "output": 247
       }
+    },
+    {
+      "Delete": {}
     }
   ]
 }
diff --git a/binaries/llm-test/src/delete.rs b/binaries/llm-test/src/delete.rs
new file mode 100644
index 00000000..404d7aba
--- /dev/null
+++ b/binaries/llm-test/src/delete.rs
@@ -0,0 +1,100 @@
+//! Tests the model's token manipulation APIs:
+//!
+//! * [llm::InferenceSession::delete_tokens()]
+//!
+//! See [crate::TestCase::Delete].
+
+use std::convert::Infallible;
+
+use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest};
+use serde::Serialize;
+
+use crate::{TestCaseReport, TestCaseReportMeta};
+
+/// Error tolerance for the float comparisons.
+const TOLERANCE: f32 = 1e-7;
+
+/// Tests that models can delete tokens without changing the model's behavior.
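+///
+/// A minimal sketch of the flow this test exercises (illustrative only;
+/// `model` may be any loaded model whose architecture supports deletion):
+///
+/// ```no_run
+/// # fn example(model: &dyn llm::Model) -> Result<(), Box<dyn std::error::Error>> {
+/// let mut session = model.start_session(Default::default());
+/// // ... feed a prompt and capture the logits ...
+/// // Rewind the last token; re-feeding it must reproduce the same logits.
+/// let deleted = session.delete_tokens(model, 1)?;
+/// assert_eq!(deleted.len(), 1);
+/// # Ok(())
+/// # }
+/// ```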
+pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport {
+    let report = DeleteReport::default();
+    let mut session = model.start_session(Default::default());
+    let mut output = OutputRequest {
+        all_logits: Some(vec![]),
+        ..Default::default()
+    };
+
+    // Feed some tokens
+    if let Err(err) = feed_prompt("The llama lived on the", &mut session, model, &mut output) {
+        return report.failure(&err.to_string());
+    }
+
+    // Add a token and get the logits
+    if let Err(err) = feed_prompt(" crab", &mut session, model, &mut output) {
+        return report.failure(&err.to_string());
+    }
+    let Some(original_logits) = output.all_logits.clone() else {
+        return report.failure("Model did not return logits.");
+    };
+
+    // Delete the token, then re-add it. Verify that the logits are the same.
+    if let Err(err) = session.delete_tokens(model, 1) {
+        return report.failure(&err.to_string());
+    }
+    if let Err(err) = feed_prompt(" crab", &mut session, model, &mut output) {
+        return report.failure(&err.to_string());
+    }
+    let Some(redone_logits) = output.all_logits.clone() else {
+        return report.failure("Second run of model did not return logits.");
+    };
+
+    // Compare the logits
+    for (idx, (&original, redone)) in original_logits.iter().zip(redone_logits).enumerate() {
+        if original > redone + TOLERANCE || original < redone - TOLERANCE {
+            return report.failure(&format!(
+                "Expected logits to be the same after delete, but they differed at \
+                 index {idx}: expected {original}, but was {redone}."
+            ));
+        }
+    }
+
+    log::info!("`can_delete` test passed (no expected output)!");
+    report.success()
+}
+
+fn feed_prompt(
+    prompt: &str,
+    session: &mut InferenceSession,
+    model: &impl Model,
+    output: &mut OutputRequest,
+) -> Result<(), llm::InferenceError> {
+    session.feed_prompt(model, &Default::default(), prompt, output, |x| {
+        always_continue(x)
+    })
+}
+
+fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
+    Ok(InferenceFeedback::Continue)
+}
+
+#[derive(Serialize, Default)]
+pub struct DeleteReport {
+    output: usize,
+}
+
+impl DeleteReport {
+    fn failure(self, msg: &str) -> TestCaseReport {
+        TestCaseReport {
+            meta: TestCaseReportMeta::Error {
+                error: msg.to_owned(),
+            },
+            report: crate::TestCaseReportInner::Delete(self),
+        }
+    }
+
+    fn success(self) -> TestCaseReport {
+        TestCaseReport {
+            meta: TestCaseReportMeta::Success,
+            report: crate::TestCaseReportInner::Delete(self),
+        }
+    }
+}
diff --git a/binaries/llm-test/src/main.rs b/binaries/llm-test/src/main.rs
index b14ed78a..b1bc9b07 100644
--- a/binaries/llm-test/src/main.rs
+++ b/binaries/llm-test/src/main.rs
@@ -1,6 +1,7 @@
 //! Test runner for all LLMs.
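+//!
+//! Test cases are declared per-architecture in the JSON configs under
+//! `binaries/llm-test/configs/`. For example, the token-deletion test added
+//! here is enabled by appending `{ "Delete": {} }` to a config's list of test
+//! cases.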
 
 mod common;
+mod delete;
 mod inference;
 mod tokens;
 
@@ -128,6 +129,7 @@ enum TestCase {
         input: String,
         output: usize,
     },
+    Delete {},
 }
 
 #[derive(Serialize)]
@@ -158,6 +160,7 @@ pub enum TestCaseReportInner {
         inference_stats: Option<InferenceStats>,
     },
     Tokens(tokens::TokensReport),
+    Delete(delete::DeleteReport),
 }
 
 async fn test_model(
@@ -278,6 +281,9 @@ async fn test_model(
             TestCase::Tokens { input, output } => {
                 test_case_reports.push(tokens::can_feed(&model, input, *output));
             }
+            TestCase::Delete {} => {
+                test_case_reports.push(delete::can_delete(&model));
+            }
         }
     }
     let first_error: Option<String> =
diff --git a/binaries/llm-test/src/tokens.rs b/binaries/llm-test/src/tokens.rs
index 52f37019..a2a490b6 100644
--- a/binaries/llm-test/src/tokens.rs
+++ b/binaries/llm-test/src/tokens.rs
@@ -6,7 +6,7 @@
 
 use std::convert::Infallible;
 
-use llm::{InferenceFeedback, Model, OutputRequest};
+use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest};
 use serde::Serialize;
 
 use crate::{TestCaseReport, TestCaseReportMeta};
@@ -14,20 +14,14 @@ use crate::{TestCaseReport, TestCaseReportMeta};
 
 /// Tests that the model performs as expected when feeding tokens
 pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize) -> TestCaseReport {
     let mut report = TokensReport::default();
     let mut session = model.start_session(Default::default());
     let mut output = OutputRequest {
         all_logits: Some(vec![]),
         ..Default::default()
     };
 
-    let feed_prompt = &mut |prompt: &str| {
-        session.feed_prompt(model, &Default::default(), prompt, &mut output, |x| {
-            always_continue(x)
-        })
-    };
-
-    if let Err(err) = feed_prompt(input) {
+    if let Err(err) = feed_prompt(input, &mut session, model, &mut output) {
         return report.failure(&err.to_string());
     };
 
@@ -62,9 +56,21 @@ pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize)
         ));
     }
 
+    log::info!("`can_feed` test passed (no expected output)!");
     report.success()
 }
 
+fn feed_prompt(
+    prompt: &str,
+    session: &mut InferenceSession,
+    model: &impl Model,
+    output: &mut OutputRequest,
+) -> Result<(), llm::InferenceError> {
+    session.feed_prompt(model, &Default::default(), prompt, output, |x| {
+        always_continue(x)
+    })
+}
+
 fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
     Ok(InferenceFeedback::Continue)
 }
diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index 37861174..2aa4754e 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -333,6 +333,37 @@ impl InferenceSession {
         Ok(())
     }
 
+    /// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`.
+    pub fn delete_tokens(
+        &mut self,
+        model: &dyn Model,
+        num: usize,
+    ) -> Result<Vec<TokenId>, DeleteError> {
+        if !model.supports_delete() {
+            return Err(DeleteError::UnsupportedArchitecture);
+        }
+
+        if num >= self.n_past {
+            return Err(DeleteError::NotEnoughTokens);
+        }
+
+        // Remove the tokens from self.tokens.
+        let token_start = self.n_past - num;
+        let deleted_tokens: Vec<_> = self.tokens.drain(token_start..).collect();
+
+        // Remove the corresponding bytes from decoded_tokens.
+        let mut decoded_start = self.decoded_tokens.len();
+        for id in &deleted_tokens {
+            decoded_start -= model.tokenizer().token(*id as usize).len();
+        }
+        self.decoded_tokens.truncate(decoded_start);
+
+        // Decrement the n_past tokens counter.
+        self.n_past -= num;
+
+        Ok(deleted_tokens)
+    }
+
     /// Infer the next token for this session.
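+    /// After a successful [`Self::delete_tokens`] call, inference resumes from
+    /// the rewound position: `n_past` and the token buffers have been trimmed,
+    /// so re-feeding the deleted text should reproduce the prior logits (this
+    /// is what the `can_delete` test in `llm-test` checks).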
     pub fn infer_next_token(
         &mut self,
@@ -637,6 +668,16 @@ pub enum InferenceError {
     UserCallback(Box<dyn std::error::Error>),
 }
 
+#[derive(Error, Debug)]
+/// Errors encountered when deleting tokens from a session.
+pub enum DeleteError {
+    #[error("tried deleting more tokens than were available")]
+    NotEnoughTokens,
+
+    #[error("model architecture does not support deletes")]
+    UnsupportedArchitecture,
+}
+
 #[derive(Error, Debug)]
 /// Errors encountered during the snapshot process.
 pub enum SnapshotError {
diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs
index 2de02828..90c86d58 100644
--- a/crates/llm-base/src/model/mod.rs
+++ b/crates/llm-base/src/model/mod.rs
@@ -86,6 +86,12 @@ pub trait KnownModel: Send + Sync {
 
     /// Get the list of regexes to use to determine if a tensor in this model should not be quantized.
     fn skip_quantize_tensors() -> Vec<Regex>;
+
+    /// Returns whether the model supports deleting tokens.
+    fn supports_delete(&self) -> bool {
+        // Assume we can't delete unless otherwise specified
+        false
+    }
 }
 
 /// A type-erased model to allow for interacting with a model without knowing
@@ -118,6 +124,9 @@ pub trait Model: Send + Sync {
 
     /// Get the end of text/end of string token ID. This value is defined by model implementers.
     fn eot_token_id(&self) -> TokenId;
+
+    /// Returns whether the model supports deleting tokens.
+    fn supports_delete(&self) -> bool;
 }
 
 impl<M: KnownModel> Model for M {
@@ -149,6 +158,10 @@ impl<M: KnownModel> Model for M {
     fn eot_token_id(&self) -> TokenId {
         KnownModel::eot_token_id(self)
     }
+
+    fn supports_delete(&self) -> bool {
+        KnownModel::supports_delete(self)
+    }
 }
 
 /// Implemented by model hyperparameters for interacting with hyperparameters
diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs
index 4e9aa192..5ef63b79 100644
--- a/crates/models/bloom/src/lib.rs
+++ b/crates/models/bloom/src/lib.rs
@@ -396,6 +396,10 @@ impl KnownModel for Bloom {
     fn skip_quantize_tensors() -> Vec<Regex> {
         vec![]
     }
+
+    fn supports_delete(&self) -> bool {
+        true
+    }
 }
 
 /// BLOOM [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs
index 195f876a..3d6cbcd2 100644
--- a/crates/models/gptj/src/lib.rs
+++ b/crates/models/gptj/src/lib.rs
@@ -318,6 +318,10 @@ impl KnownModel for GptJ {
     fn skip_quantize_tensors() -> Vec<Regex> {
         vec![]
     }
+
+    fn supports_delete(&self) -> bool {
+        true
+    }
 }
 
 /// GPT-J [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs
index 5339b901..641cfbcb 100644
--- a/crates/models/gptneox/src/lib.rs
+++ b/crates/models/gptneox/src/lib.rs
@@ -364,6 +364,10 @@ impl KnownModel for GptNeoX {
     fn skip_quantize_tensors() -> Vec<Regex> {
         vec![]
     }
+
+    fn supports_delete(&self) -> bool {
+        true
+    }
 }
 
 /// GPT-NeoX [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs
index 18991adf..757a372f 100644
--- a/crates/models/mpt/src/lib.rs
+++ b/crates/models/mpt/src/lib.rs
@@ -298,6 +298,10 @@ impl KnownModel for Mpt {
     fn skip_quantize_tensors() -> Vec<Regex> {
         vec![]
     }
+
+    fn supports_delete(&self) -> bool {
+        true
+    }
 }
 
 /// MPT [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))