diff --git a/binaries/llm-test/src/delete.rs b/binaries/llm-test/src/delete.rs index 2c609cb5..9ddbe7a8 100644 --- a/binaries/llm-test/src/delete.rs +++ b/binaries/llm-test/src/delete.rs @@ -11,9 +11,6 @@ use serde::Serialize; use crate::{TestCaseReport, TestCaseReportMeta}; -/// Error tolerance for the float comparisons. -const TOLERANCE: f32 = 1e-7; - /// Tests that models can delete tokens without changing the model's behavior. pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport { let report = DeleteReport::default(); @@ -36,8 +33,8 @@ pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport { return report.failure("Model did not return logits."); }; - // Delete, then re-add. Verify logits are the same. - if let Err(err) = session.delete_tokens(model, 1) { + // Rewind, then re-add. Verify logits are the same. + if let Err(err) = session.rewind(model, 1) { return report.failure(&err.to_string()); } if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) { @@ -49,7 +46,7 @@ pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport { // Compare the logits for (idx, (&original, redone)) in original_logits.iter().zip(redone_logits).enumerate() { - if original > redone + TOLERANCE || original < redone - TOLERANCE { + if original > redone + f32::EPSILON || original < redone - f32::EPSILON { return report.failure(&format!( "Expected logits to be the same after delete, but differed at {idx}, \ expected {original}, but was {redone}." 
@@ -57,7 +54,7 @@ pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport { } } - log::info!("`can_delete` test passed (no expected output)!"); + log::info!("`can_delete` test passed!"); report.success() } @@ -67,9 +64,7 @@ fn feed_prompt( model: &impl Model, output: &mut OutputRequest, ) -> Result<(), llm::InferenceError> { - session.feed_prompt(model, &Default::default(), prompt, output, |x| { - always_continue(x) - }) + session.feed_prompt(model, &Default::default(), prompt, output, always_continue) } fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> { diff --git a/binaries/llm-test/src/tokens.rs b/binaries/llm-test/src/tokens.rs index a2a490b6..260546b8 100644 --- a/binaries/llm-test/src/tokens.rs +++ b/binaries/llm-test/src/tokens.rs @@ -55,7 +55,7 @@ pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize) )); } - log::info!("`can_feed` test passed (no expected output)!"); + log::info!("`can_feed` test passed!"); report.success() } diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 27428d73..2c4fcf6e 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -334,17 +334,13 @@ impl InferenceSession { } /// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`. - pub fn delete_tokens( - &mut self, - model: &dyn Model, - num: usize, - ) -> Result<Vec<TokenId>, DeleteError> { - if !model.supports_delete() { - return Err(DeleteError::UnsupportedArchitecture); + pub fn rewind(&mut self, model: &dyn Model, num: usize) -> Result<Vec<TokenId>, RewindError> { + if !model.supports_rewind() { + return Err(RewindError::UnsupportedArchitecture); } if num >= self.n_past { - return Err(DeleteError::NotEnoughTokens); + return Err(RewindError::NotEnoughTokens); } // Remove the tokens from self.tokens. @@ -670,7 +666,7 @@ pub enum InferenceError { #[derive(Error, Debug)] /// Errors encountered during the snapshot process.
-pub enum DeleteError { +pub enum RewindError { /// Tried deleting more tokens than were available #[error("tried deleting more tokens than were available")] NotEnoughTokens, diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index 479aba2d..1ec18d1c 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -23,9 +23,9 @@ pub use ggml; pub use ggml::Type as ElementType; pub use inference_session::{ - feed_prompt_callback, DeleteError, GraphOutputs, InferenceError, InferenceFeedback, - InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig, - InferenceSnapshot, InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, SnapshotError, + feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest, + InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, + InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, RewindError, SnapshotError, }; pub use loader::{ load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic, diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index 90c86d58..b2b49a0a 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -88,7 +88,7 @@ pub trait KnownModel: Send + Sync { fn skip_quantize_tensors() -> Vec<Regex>; /// Returns whether the model supports deleting tokens. - fn supports_delete(&self) -> bool { + fn supports_rewind(&self) -> bool { // Assume we can't delete unless otherwise specified false } @@ -126,7 +126,7 @@ pub trait Model: Send + Sync { fn eot_token_id(&self) -> TokenId; /// Returns whether the model supports deleting tokens.
- fn supports_delete(&self) -> bool; + fn supports_rewind(&self) -> bool; } impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M { fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { @@ -159,8 +159,8 @@ impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M { KnownModel::eot_token_id(self) } - fn supports_delete(&self) -> bool { - KnownModel::supports_delete(self) + fn supports_rewind(&self) -> bool { + KnownModel::supports_rewind(self) } } diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 2cc22609..35692951 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -78,13 +78,13 @@ use std::{ // This is the "user-facing" API, and GGML may not always be our backend. pub use llm_base::{ feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout, - quantize, samplers, DeleteError, ElementType, FileType, FileTypeFormat, FormatMagic, - Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, - InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, - InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, - Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, - QuantizeProgress, Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, - TokenizationError, Tokenizer, TokenizerSource, + quantize, samplers, ElementType, FileType, FileTypeFormat, FormatMagic, Hyperparameters, + InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse, + InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, + InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model, + ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, + RewindError, Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, + Tokenizer, TokenizerSource, }; use serde::Serialize; diff --git a/crates/models/bloom/src/lib.rs
b/crates/models/bloom/src/lib.rs index 5ef63b79..0897a210 100644 --- a/crates/models/bloom/src/lib.rs +++ b/crates/models/bloom/src/lib.rs @@ -397,7 +397,7 @@ impl KnownModel for Bloom { vec![] } - fn supports_delete(&self) -> bool { + fn supports_rewind(&self) -> bool { true } } diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs index 3d6cbcd2..42a039c6 100644 --- a/crates/models/gptj/src/lib.rs +++ b/crates/models/gptj/src/lib.rs @@ -319,7 +319,7 @@ impl KnownModel for GptJ { vec![] } - fn supports_delete(&self) -> bool { + fn supports_rewind(&self) -> bool { true } } diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index 641cfbcb..e033bdbe 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -365,7 +365,7 @@ impl KnownModel for GptNeoX { vec![] } - fn supports_delete(&self) -> bool { + fn supports_rewind(&self) -> bool { true } } diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 0ea6661e..94585218 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -349,7 +349,7 @@ impl KnownModel for Llama { vec![] } - fn supports_delete(&self) -> bool { + fn supports_rewind(&self) -> bool { true } } diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs index 757a372f..203d779d 100644 --- a/crates/models/mpt/src/lib.rs +++ b/crates/models/mpt/src/lib.rs @@ -299,7 +299,7 @@ impl KnownModel for Mpt { vec![] } - fn supports_delete(&self) -> bool { + fn supports_rewind(&self) -> bool { true } }