This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit 2badcd9

Address PR comments
steventrouble committed Jul 9, 2023
1 parent 2e35b46 commit 2badcd9
Showing 11 changed files with 30 additions and 39 deletions.
15 changes: 5 additions & 10 deletions binaries/llm-test/src/delete.rs
@@ -11,9 +11,6 @@ use serde::Serialize;
 
 use crate::{TestCaseReport, TestCaseReportMeta};
 
-/// Error tolerance for the float comparisons.
-const TOLERANCE: f32 = 1e-7;
-
 /// Tests that models can delete tokens without changing the model's behavior.
 pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport {
     let report = DeleteReport::default();
@@ -36,8 +33,8 @@ pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport {
         return report.failure("Model did not return logits.");
     };
 
-    // Delete, then re-add. Verify logits are the same.
-    if let Err(err) = session.delete_tokens(model, 1) {
+    // Rewind, then re-add. Verify logits are the same.
+    if let Err(err) = session.rewind(model, 1) {
         return report.failure(&err.to_string());
     }
     if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) {
@@ -49,15 +46,15 @@ pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport {
 
     // Compare the logits
     for (idx, (&original, redone)) in original_logits.iter().zip(redone_logits).enumerate() {
-        if original > redone + TOLERANCE || original < redone - TOLERANCE {
+        if original > redone + f32::EPSILON || original < redone - f32::EPSILON {
             return report.failure(&format!(
                 "Expected logits to be the same after delete, but differed at {idx}, \
                 expected {original}, but was {redone}."
             ));
         }
     }
 
-    log::info!("`can_delete` test passed (no expected output)!");
+    log::info!("`can_delete` test passed!");
     report.success()
 }
 
@@ -67,9 +64,7 @@ fn feed_prompt(
     model: &impl Model,
     output: &mut OutputRequest,
 ) -> Result<(), llm::InferenceError> {
-    session.feed_prompt(model, &Default::default(), prompt, output, |x| {
-        always_continue(x)
-    })
+    session.feed_prompt(model, &Default::default(), prompt, output, always_continue)
 }
 
 fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
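Note on the comparison change above: it swaps the ad-hoc 1e-7 band for the standard f32::EPSILON (about 1.19e-7), removing the magic constant in favor of an absolute machine-epsilon bound. An equivalent form of the check, as a self-contained sketch (the logits_match helper is illustrative, not part of this commit):

    /// True when two logits differ by at most one f32 machine epsilon.
    /// f32::EPSILON is an absolute bound, so this is strict for values
    /// at or above 1.0 in magnitude.
    fn logits_match(original: f32, redone: f32) -> bool {
        (original - redone).abs() <= f32::EPSILON
    }

    fn main() {
        assert!(logits_match(0.25, 0.25));
        assert!(!logits_match(0.25, 0.2500004)); // off by ~4e-7, outside the band
    }

Separately, the `|x| always_continue(x)` closure in feed_prompt was replaced with the function item `always_continue`, which coerces to the same callback type without the redundant wrapper.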
2 changes: 1 addition & 1 deletion binaries/llm-test/src/tokens.rs
@@ -55,7 +55,7 @@ pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize)
         ));
     }
 
-    log::info!("`can_feed` test passed (no expected output)!");
+    log::info!("`can_feed` test passed!");
     report.success()
 }
 
14 changes: 5 additions & 9 deletions crates/llm-base/src/inference_session.rs
@@ -334,17 +334,13 @@ impl InferenceSession {
     }
 
     /// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`.
-    pub fn delete_tokens(
-        &mut self,
-        model: &dyn Model,
-        num: usize,
-    ) -> Result<Vec<TokenId>, DeleteError> {
-        if !model.supports_delete() {
-            return Err(DeleteError::UnsupportedArchitecture);
+    pub fn rewind(&mut self, model: &dyn Model, num: usize) -> Result<Vec<TokenId>, RewindError> {
+        if !model.supports_rewind() {
+            return Err(RewindError::UnsupportedArchitecture);
         }
 
         if num >= self.n_past {
-            return Err(DeleteError::NotEnoughTokens);
+            return Err(RewindError::NotEnoughTokens);
         }
 
         // Remove the tokens from self.tokens.
@@ -670,7 +666,7 @@ pub enum InferenceError {
 
 #[derive(Error, Debug)]
 /// Errors encountered during the snapshot process.
-pub enum DeleteError {
+pub enum RewindError {
     /// Tried deleting more tokens than were available
    #[error("tried deleting more tokens than were available")]
     NotEnoughTokens,
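The rename keeps the signature: rewind still returns the IDs of the removed tokens on success. A hypothetical call site under the new API (model loading and session setup elided; `model` and `session` are assumed to already exist):

    // Rewind the last two tokens. Architectures that cannot rewind report
    // RewindError::UnsupportedArchitecture instead of corrupting the session.
    match session.rewind(model, 2) {
        Ok(deleted) => log::info!("rewound {} token(s)", deleted.len()),
        Err(err) => log::warn!("rewind failed: {err}"),
    }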
6 changes: 3 additions & 3 deletions crates/llm-base/src/lib.rs
@@ -23,9 +23,9 @@ pub use ggml;
 pub use ggml::Type as ElementType;
 
 pub use inference_session::{
-    feed_prompt_callback, DeleteError, GraphOutputs, InferenceError, InferenceFeedback,
-    InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig,
-    InferenceSnapshot, InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, SnapshotError,
+    feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest,
+    InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
+    InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, RewindError, SnapshotError,
 };
 pub use loader::{
     load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic,
8 changes: 4 additions & 4 deletions crates/llm-base/src/model/mod.rs
@@ -88,7 +88,7 @@ pub trait KnownModel: Send + Sync {
     fn skip_quantize_tensors() -> Vec<Regex>;
 
     /// Returns whether the model supports deleting tokens.
-    fn supports_delete(&self) -> bool {
+    fn supports_rewind(&self) -> bool {
         // Assume we can't delete unless otherwise specified
         false
     }
@@ -126,7 +126,7 @@ pub trait Model: Send + Sync {
     fn eot_token_id(&self) -> TokenId;
 
     /// Returns whether the model supports deleting tokens.
-    fn supports_delete(&self) -> bool;
+    fn supports_rewind(&self) -> bool;
 }
 impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
     fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
@@ -159,8 +159,8 @@ impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
         KnownModel::eot_token_id(self)
     }
 
-    fn supports_delete(&self) -> bool {
-        KnownModel::supports_delete(self)
+    fn supports_rewind(&self) -> bool {
+        KnownModel::supports_rewind(self)
     }
 }
 
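Because supports_rewind defaults to false on KnownModel, every architecture must opt in explicitly, which is what the per-model diffs below do. A simplified, self-contained sketch of that opt-in pattern (the trait and types here are illustrative, not the actual llm API):

    trait SupportsRewind {
        /// Conservative default: rewind is unsupported unless a model opts in.
        fn supports_rewind(&self) -> bool {
            false
        }
    }

    struct OptedIn;  // stands in for Bloom, GptJ, GptNeoX, Llama, and Mpt below
    struct OptedOut; // a model that keeps the default

    impl SupportsRewind for OptedIn {
        fn supports_rewind(&self) -> bool {
            true
        }
    }

    impl SupportsRewind for OptedOut {}

    fn main() {
        assert!(OptedIn.supports_rewind());
        assert!(!OptedOut.supports_rewind());
    }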
14 changes: 7 additions & 7 deletions crates/llm/src/lib.rs
@@ -78,13 +78,13 @@ use std::{
 // This is the "user-facing" API, and GGML may not always be our backend.
 pub use llm_base::{
     feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout,
-    quantize, samplers, DeleteError, ElementType, FileType, FileTypeFormat, FormatMagic,
-    Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest,
-    InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
-    InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress,
-    Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError,
-    QuantizeProgress, Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer,
-    TokenizationError, Tokenizer, TokenizerSource,
+    quantize, samplers, ElementType, FileType, FileTypeFormat, FormatMagic, Hyperparameters,
+    InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse,
+    InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef,
+    InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model,
+    ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress,
+    RewindError, Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError,
+    Tokenizer, TokenizerSource,
 };
 
 use serde::Serialize;
2 changes: 1 addition & 1 deletion crates/models/bloom/src/lib.rs
@@ -397,7 +397,7 @@ impl KnownModel for Bloom {
         vec![]
     }
 
-    fn supports_delete(&self) -> bool {
+    fn supports_rewind(&self) -> bool {
         true
     }
 }
2 changes: 1 addition & 1 deletion crates/models/gptj/src/lib.rs
@@ -319,7 +319,7 @@ impl KnownModel for GptJ {
         vec![]
     }
 
-    fn supports_delete(&self) -> bool {
+    fn supports_rewind(&self) -> bool {
         true
     }
 }
2 changes: 1 addition & 1 deletion crates/models/gptneox/src/lib.rs
@@ -365,7 +365,7 @@ impl KnownModel for GptNeoX {
         vec![]
     }
 
-    fn supports_delete(&self) -> bool {
+    fn supports_rewind(&self) -> bool {
         true
     }
 }
2 changes: 1 addition & 1 deletion crates/models/llama/src/lib.rs
@@ -349,7 +349,7 @@ impl KnownModel for Llama {
         vec![]
     }
 
-    fn supports_delete(&self) -> bool {
+    fn supports_rewind(&self) -> bool {
         true
     }
 }
2 changes: 1 addition & 1 deletion crates/models/mpt/src/lib.rs
@@ -299,7 +299,7 @@ impl KnownModel for Mpt {
         vec![]
     }
 
-    fn supports_delete(&self) -> bool {
+    fn supports_rewind(&self) -> bool {
         true
    }
 }
