This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Merge pull request #334 from steventrouble/main
Add ability to delete tokens (undo feed)
philpax authored Jul 9, 2023
2 parents ae8233a + 2badcd9 commit 7f13bb9
Showing 20 changed files with 498 additions and 156 deletions.
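The headline addition is an `InferenceSession::rewind()` call that deletes the most recently fed tokens, undoing a prior `feed_prompt()`. As a minimal sketch of the flow the new tests exercise (signatures taken from `delete.rs` below; anything beyond those calls is an assumption):

use std::convert::Infallible;

use llm::{InferenceFeedback, Model, OutputRequest};

// Hypothetical helper mirroring the feed-then-undo flow in delete.rs.
fn feed_then_undo(model: &impl Model) {
    let mut session = model.start_session(Default::default());
    let mut output = OutputRequest::default();

    // Feed a prompt as usual...
    session
        .feed_prompt(
            model,
            &Default::default(),
            "The llama lived on the",
            &mut output,
            |_: &[u8]| Ok::<_, Infallible>(InferenceFeedback::Continue),
        )
        .expect("feed_prompt failed");

    // ...then delete the last fed token, restoring the previous session state.
    session.rewind(model, 1).expect("rewind failed");
}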
11 changes: 10 additions & 1 deletion binaries/llm-test/configs/bloom.json
@@ -9,6 +9,15 @@
"output_disabled": "When a llama rides a crab, ,.-\n\n/? ', , ; A;A = (b),d e orm\n“t” + “p。n unus et les el duetant alle that are by no ... ”\n( ? ) – ‘?\n!!\n«…..’,\nS.\n\n‘l」之 attergoir à dit-on pas .. 。。 ..\n– La leçon se confond quelquefois con ce qui es vée par occident .\n( 2 ) .\nLa protestation del paysan mécontent regardait pendre eussent mœurs faillite forteresse rivières lieues forteressemelés inquiétudes crackdown brawl slaughter massacresokea .\n» » … « …\n. . . \" \" ….",
"maximum_token_count": 128
}
},
{
"Tokens": {
"input": "Rustformers is",
"output": 15
}
},
{
"Delete": {}
}
]
}
}
11 changes: 10 additions & 1 deletion binaries/llm-test/configs/gptj.json
@@ -9,6 +9,15 @@
"output_disabled": "\"When a llama rides a crab, \nit's not the same as when an elephant does it.\" - John Steinbeck, East of Eden.\n\n \"The best way to predict your future is by looking at history.\"- Robert Kiyosaki (author). Rich Dad Poor dad : what 10 rules for success really mean and how you can apply them in life! The rich dads guidebook on personal finance: How To Become A Millionaire In Less Than 5 years! http://www..richdadpoordaddyguidebooksalexanderkimballblogcom/the_bestwaytopredictyourfutureislookingathistory/. You will learn about money management",
"maximum_token_count": 128
}
},
{
"Tokens": {
"input": "Rustformers is",
"output": 257
}
},
{
"Delete": {}
}
]
}
}
11 changes: 10 additions & 1 deletion binaries/llm-test/configs/gptneox.json
@@ -9,6 +9,15 @@
"output_disabled": "<|padding|>When a llama rides a crab, \n“The Greatest Show on Earth” is the title of an 1875 book by Phineas Taylor Barnum, who founded and operated The circus. He was born in Bethel Connecticut to Meshack (Meshake) Bowman Jr., from New York City; his mother’s name has not been recorded but she may have had some Native American ancestry as well.[2] His father died when he[3][4], at age three,[5]: 9–10 (p1), 11-12​—was left with relatives until they could find him work or send for them back home where there",
"maximum_token_count": 128
}
},
{
"Tokens": {
"input": "Rustformers is",
"output": 247
}
},
{
"Delete": {}
}
]
}
}
11 changes: 10 additions & 1 deletion binaries/llm-test/configs/llama.json
@@ -9,6 +9,15 @@
"output": "When a llama rides a crab, 10-year olds are the ones who get to eat.\nTheir parents have been told that they will be eating for another year or two before their children can enjoy it again – and then only if there is enough food left over from Christmas dinner!",
"maximum_token_count": 128
}
},
{
"Tokens": {
"input": "Rustformers is",
"output": 260
}
},
{
"Delete": {}
}
]
}
}
11 changes: 10 additions & 1 deletion binaries/llm-test/configs/mpt.json
@@ -9,6 +9,15 @@
"output": "When a llama rides a crab,  the llama is called the \"crab rider\".\nThe crabs are very popular in South America, especially Brazil. They have been used as transportation for many years and they can carry up to five people at once!",
"maximum_token_count": 128
}
},
{
"Tokens": {
"input": "Rustformers is",
"output": 247
}
},
{
"Delete": {}
}
]
}
}
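Every model config gains the same two cases alongside `Inference`: `Tokens` feeds the input and asserts the id of the most likely next token (15 for BLOOM, 257 for GPT-J, 247 for GPT-NeoX and MPT, 260 for LLaMA), and `Delete` triggers the feed/rewind/re-feed check in `delete.rs` below. A sketch of the serde shape these entries imply; the actual `TestCase` enum lives in the harness and this mirror is an assumption (note that renaming `output` to `output_disabled` simply leaves the field unset, keeping an expected output around without asserting on it):

use serde::Deserialize;

// Hypothetical mirror of the test-case schema used by these configs.
#[derive(Deserialize)]
enum TestCase {
    Inference {
        input: String,
        output: Option<String>, // None when the key is renamed to "output_disabled"
        maximum_token_count: usize,
    },
    Tokens {
        input: String,
        output: usize, // expected id of the most likely next token
    },
    Delete {},
}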
30 changes: 30 additions & 0 deletions binaries/llm-test/src/common.rs
@@ -0,0 +1,30 @@
//! Tests that are run on every model, regardless of config.

pub(super) fn can_send<M: llm::KnownModel + 'static>(model: M) -> anyhow::Result<M> {
    let model = std::thread::spawn(move || model)
        .join()
        .map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}"));

    log::info!("`can_send` test passed!");

    model
}

pub(super) fn can_roundtrip_hyperparameters<M: llm::KnownModel + 'static>(
    model: &M,
) -> anyhow::Result<()> {
    fn test_hyperparameters<M: llm::Hyperparameters>(hyperparameters: &M) -> anyhow::Result<()> {
        let mut data = vec![];
        hyperparameters.write_ggml(&mut data)?;
        let new_hyperparameters =
            <M as llm::Hyperparameters>::read_ggml(&mut std::io::Cursor::new(data))?;

        assert_eq!(hyperparameters, &new_hyperparameters);

        log::info!("`can_roundtrip_hyperparameters` test passed!");

        Ok(())
    }

    test_hyperparameters(model.hyperparameters())
}
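`can_send` verifies `Send`-ness at runtime by round-tripping the model through a spawned thread; the same property can also be pinned at compile time. A sketch of that alternative, not part of this diff:

// Type-checks only when T: Send, so a call such as
// `assert_send::<SomeConcreteModel>()` fails to build (rather than at
// test time) if the model ever stops being Send.
fn assert_send<T: Send>() {}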
95 changes: 95 additions & 0 deletions binaries/llm-test/src/delete.rs
@@ -0,0 +1,95 @@
//! Tests the model's token manipulation APIs:
//!
//! * [llm::InferenceSession::feed_prompt()]
//! * [llm::InferenceSession::rewind()]
//!
//! See [crate::TestCase::Delete].

use std::convert::Infallible;

use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest};
use serde::Serialize;

use crate::{TestCaseReport, TestCaseReportMeta};

/// Tests that models can delete tokens without changing the model's behavior.
pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport {
    let report = DeleteReport::default();
    let mut session = model.start_session(Default::default());
    let mut output = OutputRequest {
        all_logits: Some(vec![]),
        ..Default::default()
    };

    // Feed some tokens
    if let Err(err) = feed_prompt("The llama lived on the", &mut session, model, &mut output) {
        return report.failure(&err.to_string());
    }

    // Add token and get the logits
    if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) {
        return report.failure(&err.to_string());
    }
    let Some(original_logits) = output.all_logits.clone() else {
        return report.failure("Model did not return logits.");
    };

    // Rewind, then re-add. Verify logits are the same.
    if let Err(err) = session.rewind(model, 1) {
        return report.failure(&err.to_string());
    }
    if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) {
        return report.failure(&err.to_string());
    }
    let Some(redone_logits) = output.all_logits.clone() else {
        return report.failure("Second run of model did not return logits.");
    };

    // Compare the logits
    for (idx, (&original, redone)) in original_logits.iter().zip(redone_logits).enumerate() {
        if original > redone + f32::EPSILON || original < redone - f32::EPSILON {
            return report.failure(&format!(
                "Expected logits to be the same after delete, but differed at {idx}, \
                expected {original}, but was {redone}."
            ));
        }
    }

    log::info!("`can_delete` test passed!");
    report.success()
}

fn feed_prompt(
    prompt: &str,
    session: &mut InferenceSession,
    model: &impl Model,
    output: &mut OutputRequest,
) -> Result<(), llm::InferenceError> {
    session.feed_prompt(model, &Default::default(), prompt, output, always_continue)
}

fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
    Ok(InferenceFeedback::Continue)
}

#[derive(Serialize, Default)]
pub struct DeleteReport {
    output: usize,
}

impl DeleteReport {
    fn failure(self, msg: &str) -> TestCaseReport {
        TestCaseReport {
            meta: TestCaseReportMeta::Error {
                error: msg.to_owned(),
            },
            report: crate::TestCaseReportInner::Delete(self),
        }
    }

    fn success(self) -> TestCaseReport {
        TestCaseReport {
            meta: TestCaseReportMeta::Success,
            report: crate::TestCaseReportInner::Delete(self),
        }
    }
}
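For orientation, the harness can route the new config cases to these tests roughly as follows; `TestCase` and `tokens::can_feed` are assumptions based on the configs and doc comments above, since the actual dispatch (in the binary's main) is not shown in this excerpt:

// Hypothetical dispatch over the deserialized test cases:
match test_case {
    TestCase::Tokens { input, output } => tokens::can_feed(model, &input, output),
    TestCase::Delete {} => delete::can_delete(model),
    // Inference cases are handled by inference::can_infer (next file).
}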
116 changes: 116 additions & 0 deletions binaries/llm-test/src/inference.rs
@@ -0,0 +1,116 @@
//! Tests the model's inference APIs.
//!
//! See [crate::TestCase::Inference].

use std::{convert::Infallible, sync::Arc};

use llm::InferenceStats;

use crate::{ModelConfig, TestCaseReport, TestCaseReportInner, TestCaseReportMeta};

pub(crate) fn can_infer(
    model: &dyn llm::Model,
    model_config: &ModelConfig,
    input: &str,
    expected_output: Option<&str>,
    maximum_token_count: usize,
) -> anyhow::Result<TestCaseReport> {
    let mut session = model.start_session(Default::default());
    let (actual_output, res) = run_inference(
        model,
        model_config,
        &mut session,
        input,
        maximum_token_count,
    );

    // Process the results
    Ok(TestCaseReport {
        meta: match &res {
            Ok(_) => match expected_output {
                Some(expected_output) => {
                    if expected_output == actual_output {
                        log::info!("`can_infer` test passed!");
                        TestCaseReportMeta::Success
                    } else {
                        TestCaseReportMeta::Error {
                            error: "The output did not match the expected output.".to_string(),
                        }
                    }
                }
                None => {
                    log::info!("`can_infer` test passed (no expected output)!");
                    TestCaseReportMeta::Success
                }
            },
            Err(err) => TestCaseReportMeta::Error {
                error: err.to_string(),
            },
        },
        report: TestCaseReportInner::Inference {
            input: input.into(),
            expect_output: expected_output.map(|s| s.to_string()),
            actual_output,
            inference_stats: res.ok(),
        },
    })
}

fn run_inference(
    model: &dyn llm::Model,
    model_config: &ModelConfig,
    session: &mut llm::InferenceSession,
    input: &str,
    maximum_token_count: usize,
) -> (String, Result<InferenceStats, llm::InferenceError>) {
    let mut actual_output: String = String::new();
    let res = session.infer::<Infallible>(
        model,
        &mut rand::rngs::mock::StepRng::new(0, 1),
        &llm::InferenceRequest {
            prompt: input.into(),
            parameters: &llm::InferenceParameters {
                n_threads: model_config.threads,
                n_batch: 1,
                sampler: Arc::new(DeterministicSampler),
            },
            play_back_previous_tokens: false,
            maximum_token_count: Some(maximum_token_count),
        },
        &mut Default::default(),
        |r| match r {
            llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => {
                actual_output += &t;
                Ok(llm::InferenceFeedback::Continue)
            }
            _ => Ok(llm::InferenceFeedback::Continue),
        },
    );

    (actual_output, res)
}

#[derive(Debug)]
struct DeterministicSampler;
impl llm::Sampler for DeterministicSampler {
    fn sample(
        &self,
        previous_tokens: &[llm::TokenId],
        logits: &[f32],
        _rng: &mut dyn rand::RngCore,
    ) -> llm::TokenId {
        // Takes the most likely element from the logits, except if they've
        // appeared in `previous_tokens` at all
        let mut logits = logits.to_vec();
        for &token in previous_tokens {
            logits[token as usize] = f32::NEG_INFINITY;
        }

        logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap()
            .0 as llm::TokenId
    }
}
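Banning every previously sampled token makes greedy decoding deterministic and loop-free, which is what lets the configs pin exact expected outputs. A standalone sketch of the select-highest-unseen rule on toy logits (not part of the diff):

fn main() {
    // Toy logits for a 3-token vocabulary; token 1 was already emitted.
    let previous_tokens = [1usize];
    let mut logits = vec![0.1f32, 0.9, 0.5];
    for &t in &previous_tokens {
        logits[t] = f32::NEG_INFINITY; // ban repeats
    }
    let best = logits
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(i, _)| i)
        .unwrap();
    assert_eq!(best, 2); // 0.9 at index 1 is banned, so 0.5 at index 2 wins
}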