diff --git a/binaries/llm-test/configs/bloom.json b/binaries/llm-test/configs/bloom.json index f8e0ad18..5383386d 100644 --- a/binaries/llm-test/configs/bloom.json +++ b/binaries/llm-test/configs/bloom.json @@ -9,6 +9,15 @@ "output_disabled": "When a llama rides a crab, ,.-\n\n/? ', , ; A;A = (b),d e orm\n“t” + “p。n unus et les el duetant alle that are by no ... ”\n( ? ) – ‘?\n!!\n«…..’,\nS.\n\n‘l」之 attergoir à dit-on pas .. 。。 ..\n– La leçon se confond quelquefois con ce qui es vée par occident .\n( 2 ) .\nLa protestation del paysan mécontent regardait pendre eussent mœurs faillite forteresse rivières lieues forteressemelés inquiétudes crackdown brawl slaughter massacresokea .\n» » … « …\n. . . \" \" ….", "maximum_token_count": 128 } + }, + { + "Tokens": { + "input": "Rustformers is", + "output": 15 + } + }, + { + "Delete": {} } ] -} \ No newline at end of file +} diff --git a/binaries/llm-test/configs/gptj.json b/binaries/llm-test/configs/gptj.json index 8eb832c3..50966748 100644 --- a/binaries/llm-test/configs/gptj.json +++ b/binaries/llm-test/configs/gptj.json @@ -9,6 +9,15 @@ "output_disabled": "\"When a llama rides a crab, \nit's not the same as when an elephant does it.\" - John Steinbeck, East of Eden.\n\n \"The best way to predict your future is by looking at history.\"- Robert Kiyosaki (author). Rich Dad Poor dad : what 10 rules for success really mean and how you can apply them in life! The rich dads guidebook on personal finance: How To Become A Millionaire In Less Than 5 years! http://www..richdadpoordaddyguidebooksalexanderkimballblogcom/the_bestwaytopredictyourfutureislookingathistory/. You will learn about money management", "maximum_token_count": 128 } + }, + { + "Tokens": { + "input": "Rustformers is", + "output": 257 + } + }, + { + "Delete": {} } ] -} \ No newline at end of file +} diff --git a/binaries/llm-test/configs/gptneox.json b/binaries/llm-test/configs/gptneox.json index 3aab6a99..c8cce4d9 100644 --- a/binaries/llm-test/configs/gptneox.json +++ b/binaries/llm-test/configs/gptneox.json @@ -9,6 +9,15 @@ "output_disabled": "<|padding|>When a llama rides a crab, \n“The Greatest Show on Earth” is the title of an 1875 book by Phineas Taylor Barnum, who founded and operated The circus. He was born in Bethel Connecticut to Meshack (Meshake) Bowman Jr., from New York City; his mother’s name has not been recorded but she may have had some Native American ancestry as well.[2] His father died when he[3][4], at age three,[5]: 9–10 (p1), 11-12​—was left with relatives until they could find him work or send for them back home where there", "maximum_token_count": 128 } + }, + { + "Tokens": { + "input": "Rustformers is", + "output": 247 + } + }, + { + "Delete": {} } ] -} \ No newline at end of file +} diff --git a/binaries/llm-test/configs/llama.json b/binaries/llm-test/configs/llama.json index c7e485e3..9bd6094a 100644 --- a/binaries/llm-test/configs/llama.json +++ b/binaries/llm-test/configs/llama.json @@ -9,6 +9,15 @@ "output": "When a llama rides a crab, 10-year olds are the ones who get to eat.\nTheir parents have been told that they will be eating for another year or two before their children can enjoy it again – and then only if there is enough food left over from Christmas dinner!", "maximum_token_count": 128 } + }, + { + "Tokens": { + "input": "Rustformers is", + "output": 260 + } + }, + { + "Delete": {} } ] -} \ No newline at end of file +} diff --git a/binaries/llm-test/configs/mpt.json b/binaries/llm-test/configs/mpt.json index 6dd316ee..57a8bc89 100644 --- a/binaries/llm-test/configs/mpt.json +++ b/binaries/llm-test/configs/mpt.json @@ -9,6 +9,15 @@ "output": "When a llama rides a crab,  the llama is called the \"crab rider\".\nThe crabs are very popular in South America, especially Brazil. They have been used as transportation for many years and they can carry up to five people at once!", "maximum_token_count": 128 } + }, + { + "Tokens": { + "input": "Rustformers is", + "output": 247 + } + }, + { + "Delete": {} } ] -} \ No newline at end of file +} diff --git a/binaries/llm-test/src/common.rs b/binaries/llm-test/src/common.rs new file mode 100644 index 00000000..4c858820 --- /dev/null +++ b/binaries/llm-test/src/common.rs @@ -0,0 +1,30 @@ +//! Tests that are run on every model, regardless of config. + +pub(super) fn can_send(model: M) -> anyhow::Result { + let model = std::thread::spawn(move || model) + .join() + .map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}")); + + log::info!("`can_send` test passed!"); + + model +} + +pub(super) fn can_roundtrip_hyperparameters( + model: &M, +) -> anyhow::Result<()> { + fn test_hyperparameters(hyperparameters: &M) -> anyhow::Result<()> { + let mut data = vec![]; + hyperparameters.write_ggml(&mut data)?; + let new_hyperparameters = + ::read_ggml(&mut std::io::Cursor::new(data))?; + + assert_eq!(hyperparameters, &new_hyperparameters); + + log::info!("`can_roundtrip_hyperparameters` test passed!"); + + Ok(()) + } + + test_hyperparameters(model.hyperparameters()) +} diff --git a/binaries/llm-test/src/delete.rs b/binaries/llm-test/src/delete.rs new file mode 100644 index 00000000..9ddbe7a8 --- /dev/null +++ b/binaries/llm-test/src/delete.rs @@ -0,0 +1,95 @@ +//! Tests the model's token manipulation APIs: +//! +//! * [llm::InferenceSession::feed_prompt()] +//! +//! See [crate::TestCase::Tokens]. + +use std::convert::Infallible; + +use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest}; +use serde::Serialize; + +use crate::{TestCaseReport, TestCaseReportMeta}; + +/// Tests that models can delete tokens without changing the model's behavior. +pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport { + let report = DeleteReport::default(); + let mut session = model.start_session(Default::default()); + let mut output = OutputRequest { + all_logits: Some(vec![]), + ..Default::default() + }; + + // Feed some tokens + if let Err(err) = feed_prompt("The llama lived on the", &mut session, model, &mut output) { + return report.failure(&err.to_string()); + } + + // Add token and get the logits + if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) { + return report.failure(&err.to_string()); + } + let Some(original_logits) = output.all_logits.clone() else { + return report.failure("Model did not return logits."); + }; + + // Rewind, then re-add. Verify logits are the same. + if let Err(err) = session.rewind(model, 1) { + return report.failure(&err.to_string()); + } + if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) { + return report.failure(&err.to_string()); + } + let Some(redone_logits) = output.all_logits.clone() else { + return report.failure("Second run of model did not return logits."); + }; + + // Compare the logits + for (idx, (&original, redone)) in original_logits.iter().zip(redone_logits).enumerate() { + if original > redone + f32::EPSILON || original < redone - f32::EPSILON { + return report.failure(&format!( + "Expected logits to be the same after delete, but differed at {idx}, \ + expected {original}, but was {redone}." + )); + } + } + + log::info!("`can_delete` test passed!"); + report.success() +} + +fn feed_prompt( + prompt: &str, + session: &mut InferenceSession, + model: &impl Model, + output: &mut OutputRequest, +) -> Result<(), llm::InferenceError> { + session.feed_prompt(model, &Default::default(), prompt, output, always_continue) +} + +fn always_continue(_: &[u8]) -> Result { + Ok(InferenceFeedback::Continue) +} + +#[derive(Serialize, Default)] +pub struct DeleteReport { + output: usize, +} + +impl DeleteReport { + fn failure(self, msg: &str) -> TestCaseReport { + TestCaseReport { + meta: TestCaseReportMeta::Error { + error: msg.to_owned(), + }, + report: crate::TestCaseReportInner::Delete(self), + } + } + + fn success(self) -> TestCaseReport { + TestCaseReport { + meta: TestCaseReportMeta::Success, + report: crate::TestCaseReportInner::Delete(self), + } + } +} diff --git a/binaries/llm-test/src/inference.rs b/binaries/llm-test/src/inference.rs new file mode 100644 index 00000000..ec803cdd --- /dev/null +++ b/binaries/llm-test/src/inference.rs @@ -0,0 +1,116 @@ +//! Tests the model's inference APIs. +//! +//! See [crate::TestCase::Inference]. + +use std::{convert::Infallible, sync::Arc}; + +use llm::InferenceStats; + +use crate::{ModelConfig, TestCaseReport, TestCaseReportInner, TestCaseReportMeta}; + +pub(crate) fn can_infer( + model: &dyn llm::Model, + model_config: &ModelConfig, + input: &str, + expected_output: Option<&str>, + maximum_token_count: usize, +) -> anyhow::Result { + let mut session = model.start_session(Default::default()); + let (actual_output, res) = run_inference( + model, + model_config, + &mut session, + input, + maximum_token_count, + ); + + // Process the results + Ok(TestCaseReport { + meta: match &res { + Ok(_) => match expected_output { + Some(expected_output) => { + if expected_output == actual_output { + log::info!("`can_infer` test passed!"); + TestCaseReportMeta::Success + } else { + TestCaseReportMeta::Error { + error: "The output did not match the expected output.".to_string(), + } + } + } + None => { + log::info!("`can_infer` test passed (no expected output)!"); + TestCaseReportMeta::Success + } + }, + Err(err) => TestCaseReportMeta::Error { + error: err.to_string(), + }, + }, + report: TestCaseReportInner::Inference { + input: input.into(), + expect_output: expected_output.map(|s| s.to_string()), + actual_output, + inference_stats: res.ok(), + }, + }) +} + +fn run_inference( + model: &dyn llm::Model, + model_config: &ModelConfig, + session: &mut llm::InferenceSession, + input: &str, + maximum_token_count: usize, +) -> (String, Result) { + let mut actual_output: String = String::new(); + let res = session.infer::( + model, + &mut rand::rngs::mock::StepRng::new(0, 1), + &llm::InferenceRequest { + prompt: input.into(), + parameters: &llm::InferenceParameters { + n_threads: model_config.threads, + n_batch: 1, + sampler: Arc::new(DeterministicSampler), + }, + play_back_previous_tokens: false, + maximum_token_count: Some(maximum_token_count), + }, + &mut Default::default(), + |r| match r { + llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => { + actual_output += &t; + Ok(llm::InferenceFeedback::Continue) + } + _ => Ok(llm::InferenceFeedback::Continue), + }, + ); + + (actual_output, res) +} + +#[derive(Debug)] +struct DeterministicSampler; +impl llm::Sampler for DeterministicSampler { + fn sample( + &self, + previous_tokens: &[llm::TokenId], + logits: &[f32], + _rng: &mut dyn rand::RngCore, + ) -> llm::TokenId { + // Takes the most likely element from the logits, except if they've appeared in `previous_tokens` + // at all + let mut logits = logits.to_vec(); + for &token in previous_tokens { + logits[token as usize] = f32::NEG_INFINITY; + } + + logits + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .unwrap() + .0 as llm::TokenId + } +} diff --git a/binaries/llm-test/src/main.rs b/binaries/llm-test/src/main.rs index aa259779..b1bc9b07 100644 --- a/binaries/llm-test/src/main.rs +++ b/binaries/llm-test/src/main.rs @@ -1,3 +1,10 @@ +//! Test runner for all LLMs. + +mod common; +mod delete; +mod inference; +mod tokens; + use anyhow::Context; use clap::Parser; use indicatif::{ProgressBar, ProgressStyle}; @@ -7,13 +14,11 @@ use serde::{Deserialize, Serialize}; use std::{ cmp::min, collections::HashMap, - convert::Infallible, env, fs::{self, File}, io::Write, path::{Path, PathBuf}, str::FromStr, - sync::Arc, time::Instant, }; @@ -120,6 +125,11 @@ enum TestCase { output: Option, maximum_token_count: usize, }, + Tokens { + input: String, + output: usize, + }, + Delete {}, } #[derive(Serialize)] @@ -142,13 +152,15 @@ enum TestCaseReportMeta { } #[derive(Serialize)] -enum TestCaseReportInner { +pub enum TestCaseReportInner { Inference { input: String, expect_output: Option, actual_output: String, inference_stats: Option, }, + Tokens(tokens::TokensReport), + Delete(delete::DeleteReport), } async fn test_model( @@ -240,10 +252,10 @@ async fn test_model( // // Confirm that the model can be sent to a thread, then sent back - let model = tests::can_send(model)?; + let model = common::can_send(model)?; // Confirm that the hyperparameters can be roundtripped - tests::can_roundtrip_hyperparameters(&model)?; + common::can_roundtrip_hyperparameters(&model)?; // @@ -259,13 +271,19 @@ async fn test_model( input, output, maximum_token_count, - } => test_case_reports.push(tests::can_infer( + } => test_case_reports.push(inference::can_infer( &model, model_config, input, output.as_deref(), *maximum_token_count, )?), + TestCase::Tokens { input, output } => { + test_case_reports.push(tokens::can_feed(&model, input, *output)); + } + TestCase::Delete {} => { + test_case_reports.push(delete::can_delete(&model)); + } } } let first_error: Option = @@ -320,148 +338,6 @@ fn write_report( Ok(()) } -mod tests { - use super::*; - - pub(super) fn can_send(model: M) -> anyhow::Result { - let model = std::thread::spawn(move || model) - .join() - .map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}")); - - log::info!("`can_send` test passed!"); - - model - } - - pub(super) fn can_roundtrip_hyperparameters( - model: &M, - ) -> anyhow::Result<()> { - fn test_hyperparameters( - hyperparameters: &M, - ) -> anyhow::Result<()> { - let mut data = vec![]; - hyperparameters.write_ggml(&mut data)?; - let new_hyperparameters = - ::read_ggml(&mut std::io::Cursor::new(data))?; - - assert_eq!(hyperparameters, &new_hyperparameters); - - log::info!("`can_roundtrip_hyperparameters` test passed!"); - - Ok(()) - } - - test_hyperparameters(model.hyperparameters()) - } - - pub(super) fn can_infer( - model: &dyn llm::Model, - model_config: &ModelConfig, - input: &str, - expected_output: Option<&str>, - maximum_token_count: usize, - ) -> anyhow::Result { - let mut session = model.start_session(Default::default()); - let (actual_output, res) = run_inference( - model, - model_config, - &mut session, - input, - maximum_token_count, - ); - - // Process the results - Ok(TestCaseReport { - meta: match &res { - Ok(_) => match expected_output { - Some(expected_output) => { - if expected_output == actual_output { - log::info!("`can_infer` test passed!"); - TestCaseReportMeta::Success - } else { - TestCaseReportMeta::Error { - error: "The output did not match the expected output.".to_string(), - } - } - } - None => { - log::info!("`can_infer` test passed (no expected output)!"); - TestCaseReportMeta::Success - } - }, - Err(err) => TestCaseReportMeta::Error { - error: err.to_string(), - }, - }, - report: TestCaseReportInner::Inference { - input: input.into(), - expect_output: expected_output.map(|s| s.to_string()), - actual_output, - inference_stats: res.ok(), - }, - }) - } -} - -fn run_inference( - model: &dyn llm::Model, - model_config: &ModelConfig, - session: &mut llm::InferenceSession, - input: &str, - maximum_token_count: usize, -) -> (String, Result) { - let mut actual_output: String = String::new(); - let res = session.infer::( - model, - &mut rand::rngs::mock::StepRng::new(0, 1), - &llm::InferenceRequest { - prompt: input.into(), - parameters: &llm::InferenceParameters { - n_threads: model_config.threads, - n_batch: 1, - sampler: Arc::new(DeterministicSampler), - }, - play_back_previous_tokens: false, - maximum_token_count: Some(maximum_token_count), - }, - &mut Default::default(), - |r| match r { - llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => { - actual_output += &t; - Ok(llm::InferenceFeedback::Continue) - } - _ => Ok(llm::InferenceFeedback::Continue), - }, - ); - - (actual_output, res) -} - -#[derive(Debug)] -struct DeterministicSampler; -impl llm::Sampler for DeterministicSampler { - fn sample( - &self, - previous_tokens: &[llm::TokenId], - logits: &[f32], - _rng: &mut dyn rand::RngCore, - ) -> llm::TokenId { - // Takes the most likely element from the logits, except if they've appeared in `previous_tokens` - // at all - let mut logits = logits.to_vec(); - for &token in previous_tokens { - logits[token as usize] = f32::NEG_INFINITY; - } - - logits - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) - .unwrap() - .0 as llm::TokenId - } -} - async fn download_file(url: &str, local_path: &Path) -> anyhow::Result<()> { if local_path.exists() { return Ok(()); diff --git a/binaries/llm-test/src/tokens.rs b/binaries/llm-test/src/tokens.rs new file mode 100644 index 00000000..260546b8 --- /dev/null +++ b/binaries/llm-test/src/tokens.rs @@ -0,0 +1,98 @@ +//! Tests the model's token manipulation APIs: +//! +//! * [llm::InferenceSession::feed_prompt()] +//! +//! See [crate::TestCase::Tokens]. + +use std::convert::Infallible; + +use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest}; +use serde::Serialize; + +use crate::{TestCaseReport, TestCaseReportMeta}; + +/// Tests that the model performs as expected when feeding tokens +pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize) -> TestCaseReport { + let mut report = TokensReport::default(); + let mut session = model.start_session(Default::default()); + let mut output = OutputRequest { + all_logits: Some(vec![]), + ..Default::default() + }; + + if let Err(err) = feed_prompt(input, &mut session, model, &mut output) { + return report.failure(&err.to_string()); + }; + + let top_token; + match output.all_logits { + Some(logits) => { + let start = logits.len() - model.tokenizer().len(); + let mut iter = logits[start..].iter().enumerate(); + let Some((mut max_idx, mut max)) = iter.next() else { + return report.failure("Could not find any logits for last token."); + }; + for (idx, score) in iter { + if score > max { + max = score; + max_idx = idx; + } + } + top_token = max_idx; + } + None => return report.failure("Model did not output any logits."), + } + + report.output = top_token; + + if top_token != expected_output { + let tokenizer = model.tokenizer(); + let top_token_str = String::from_utf8_lossy(&tokenizer.token(top_token)).to_string(); + let expected_str = String::from_utf8_lossy(&tokenizer.token(expected_output)).to_string(); + return report.failure(&format!( + "Expected top token to be {expected_output} ({expected_str}), \ + but was {top_token} ({top_token_str})" + )); + } + + log::info!("`can_feed` test passed!"); + report.success() +} + +fn feed_prompt( + prompt: &str, + session: &mut InferenceSession, + model: &impl Model, + output: &mut OutputRequest, +) -> Result<(), llm::InferenceError> { + session.feed_prompt(model, &Default::default(), prompt, output, |x| { + always_continue(x) + }) +} + +fn always_continue(_: &[u8]) -> Result { + Ok(InferenceFeedback::Continue) +} + +#[derive(Serialize, Default)] +pub struct TokensReport { + output: usize, +} + +impl TokensReport { + fn failure(self, msg: &str) -> TestCaseReport { + TestCaseReport { + meta: TestCaseReportMeta::Error { + error: msg.to_owned(), + }, + report: crate::TestCaseReportInner::Tokens(self), + } + } + + fn success(self) -> TestCaseReport { + TestCaseReport { + meta: TestCaseReportMeta::Success, + report: crate::TestCaseReportInner::Tokens(self), + } + } +} diff --git a/crates/ggml/sys/build.rs b/crates/ggml/sys/build.rs index 084f09de..3d86d0c0 100644 --- a/crates/ggml/sys/build.rs +++ b/crates/ggml/sys/build.rs @@ -5,6 +5,8 @@ use std::path::{Path, PathBuf}; // the host and target are the same. If they are not, it will turn off auto-feature-detection, // and you will need to manually specify target features through target-features. fn main() { + verify_state(); + println!("cargo:rerun-if-changed=llama-cpp"); let mut builder = cc::Build::new(); @@ -104,6 +106,14 @@ fn main() { } } +/// Verify the state of the repo to catch common newbie mistakes. +fn verify_state() { + assert!( + Path::new("llama-cpp/ggml.c").exists(), + "Could not find llama-cpp/ggml.c. Try running `git submodule update --init`" + ); +} + fn cfg_cublas() -> bool { !cfg!(target_os = "macos") && cfg!(feature = "cublas") } diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 37861174..2c4fcf6e 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -333,6 +333,33 @@ impl InferenceSession { Ok(()) } + /// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`. + pub fn rewind(&mut self, model: &dyn Model, num: usize) -> Result, RewindError> { + if !model.supports_rewind() { + return Err(RewindError::UnsupportedArchitecture); + } + + if num >= self.n_past { + return Err(RewindError::NotEnoughTokens); + } + + // Remove the tokens from self.tokens. + let token_start = self.n_past - num; + let deleted_tokens: Vec<_> = self.tokens.drain(token_start..).collect(); + + // Remove the corresponding chars from decoded + let mut decoded_start = self.decoded_tokens.len(); + for id in &deleted_tokens { + decoded_start -= model.tokenizer().token(*id as usize).len(); + } + self.decoded_tokens.truncate(decoded_start); + + // Decrement the n_past tokens counter. + self.n_past -= num; + + Ok(deleted_tokens) + } + /// Infer the next token for this session. pub fn infer_next_token( &mut self, @@ -637,6 +664,18 @@ pub enum InferenceError { UserCallback(Box), } +#[derive(Error, Debug)] +/// Errors encountered during the snapshot process. +pub enum RewindError { + /// Tried deleting more tokens than were available + #[error("tried deleting more tokens than were available")] + NotEnoughTokens, + + /// Model architecture does not support delete + #[error("model architecture does not support deletes")] + UnsupportedArchitecture, +} + #[derive(Error, Debug)] /// Errors encountered during the snapshot process. pub enum SnapshotError { diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index d40a9077..1ec18d1c 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -25,7 +25,7 @@ pub use ggml::Type as ElementType; pub use inference_session::{ feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, - InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, SnapshotError, + InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, RewindError, SnapshotError, }; pub use loader::{ load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic, diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index 2de02828..b2b49a0a 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -86,6 +86,12 @@ pub trait KnownModel: Send + Sync { /// Get the list of regexes to use to determine if a tensor in this model should not be quantized. fn skip_quantize_tensors() -> Vec; + + /// Returns whether the model supports deleting tokens. + fn supports_rewind(&self) -> bool { + // Assume we can't delete unless otherwise specified + false + } } /// A type-erased model to allow for interacting with a model without knowing @@ -118,6 +124,9 @@ pub trait Model: Send + Sync { /// Get the end of text/end of string token ID. This value is defined by model implementers. fn eot_token_id(&self) -> TokenId; + + /// Returns whether the model supports deleting tokens. + fn supports_rewind(&self) -> bool; } impl> Model for M { fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { @@ -149,6 +158,10 @@ impl> Model for M { fn eot_token_id(&self) -> TokenId { KnownModel::eot_token_id(self) } + + fn supports_rewind(&self) -> bool { + KnownModel::supports_rewind(self) + } } /// Implemented by model hyperparameters for interacting with hyperparameters diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 2be90739..35692951 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -83,8 +83,8 @@ pub use llm_base::{ InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, - Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, - TokenizerSource, + RewindError, Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, + Tokenizer, TokenizerSource, }; use serde::Serialize; diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs index 4e9aa192..0897a210 100644 --- a/crates/models/bloom/src/lib.rs +++ b/crates/models/bloom/src/lib.rs @@ -396,6 +396,10 @@ impl KnownModel for Bloom { fn skip_quantize_tensors() -> Vec { vec![] } + + fn supports_rewind(&self) -> bool { + true + } } /// BLOOM [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs index 195f876a..42a039c6 100644 --- a/crates/models/gptj/src/lib.rs +++ b/crates/models/gptj/src/lib.rs @@ -318,6 +318,10 @@ impl KnownModel for GptJ { fn skip_quantize_tensors() -> Vec { vec![] } + + fn supports_rewind(&self) -> bool { + true + } } /// GPT-J [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index 5339b901..e033bdbe 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -364,6 +364,10 @@ impl KnownModel for GptNeoX { fn skip_quantize_tensors() -> Vec { vec![] } + + fn supports_rewind(&self) -> bool { + true + } } /// GPT-NeoX [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 6e7d4b11..94585218 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -348,6 +348,10 @@ impl KnownModel for Llama { fn skip_quantize_tensors() -> Vec { vec![] } + + fn supports_rewind(&self) -> bool { + true + } } /// LLaMA [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs index 18991adf..203d779d 100644 --- a/crates/models/mpt/src/lib.rs +++ b/crates/models/mpt/src/lib.rs @@ -298,6 +298,10 @@ impl KnownModel for Mpt { fn skip_quantize_tensors() -> Vec { vec![] } + + fn supports_rewind(&self) -> bool { + true + } } /// MPT [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))