
Commit

Refactor tests to put each type in its own file.
This'll help when I add my tests, because otherwise this file will explode in size
steventrouble committed Jul 7, 2023
1 parent 2d939e6 commit 2646aba
Showing 3 changed files with 152 additions and 147 deletions.
30 changes: 30 additions & 0 deletions binaries/llm-test/src/common.rs
@@ -0,0 +1,30 @@
//! Tests that are run on every model, regardless of config.

pub(super) fn can_send<M: llm::KnownModel + 'static>(model: M) -> anyhow::Result<M> {
let model = std::thread::spawn(move || model)
.join()
.map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}"));

log::info!("`can_send` test passed!");

model
}

pub(super) fn can_roundtrip_hyperparameters<M: llm::KnownModel + 'static>(
model: &M,
) -> anyhow::Result<()> {
fn test_hyperparameters<M: llm::Hyperparameters>(hyperparameters: &M) -> anyhow::Result<()> {
let mut data = vec![];
hyperparameters.write_ggml(&mut data)?;
let new_hyperparameters =
<M as llm::Hyperparameters>::read_ggml(&mut std::io::Cursor::new(data))?;

assert_eq!(hyperparameters, &new_hyperparameters);

log::info!("`can_roundtrip_hyperparameters` test passed!");

Ok(())
}

test_hyperparameters(model.hyperparameters())
}
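
A quick aside on why the `can_send` helper works: `std::thread::spawn` only accepts closures whose captured values (and return type) are `Send + 'static`, so moving the model into a thread and joining it back is a compile-time proof that the model can cross thread boundaries. The following is a minimal standalone sketch of the same check, assuming the `anyhow` crate is available; `DummyModel` is a placeholder type for illustration only, not something from this repository:

use std::thread;

// Placeholder standing in for a model; the real helper is generic over
// `M: llm::KnownModel + 'static`.
#[derive(Debug)]
struct DummyModel {
    layers: usize,
}

// Moving the value into a spawned thread and joining it back only compiles if
// the type is `Send + 'static`, which is exactly what `can_send` verifies.
fn can_send<T: Send + 'static>(value: T) -> anyhow::Result<T> {
    thread::spawn(move || value)
        .join()
        .map_err(|e| anyhow::anyhow!("failed to join thread: {e:?}"))
}

fn main() -> anyhow::Result<()> {
    let model = DummyModel { layers: 12 };
    let model = can_send(model)?;
    println!("model survived the round trip: {model:?}");
    Ok(())
}
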
114 changes: 114 additions & 0 deletions binaries/llm-test/src/inference.rs
@@ -0,0 +1,114 @@
//! Test cases for [crate::TestCase::Inference] tests.

use std::{convert::Infallible, sync::Arc};

use llm::InferenceStats;

use crate::{ModelConfig, TestCaseReport, TestCaseReportInner, TestCaseReportMeta};

pub(super) fn can_infer(
model: &dyn llm::Model,
model_config: &ModelConfig,
input: &str,
expected_output: Option<&str>,
maximum_token_count: usize,
) -> anyhow::Result<TestCaseReport> {
let mut session = model.start_session(Default::default());
let (actual_output, res) = run_inference(
model,
model_config,
&mut session,
input,
maximum_token_count,
);

// Process the results
Ok(TestCaseReport {
meta: match &res {
Ok(_) => match expected_output {
Some(expected_output) => {
if expected_output == actual_output {
log::info!("`can_infer` test passed!");
TestCaseReportMeta::Success
} else {
TestCaseReportMeta::Error {
error: "The output did not match the expected output.".to_string(),
}
}
}
None => {
log::info!("`can_infer` test passed (no expected output)!");
TestCaseReportMeta::Success
}
},
Err(err) => TestCaseReportMeta::Error {
error: err.to_string(),
},
},
report: TestCaseReportInner::Inference {
input: input.into(),
expect_output: expected_output.map(|s| s.to_string()),
actual_output,
inference_stats: res.ok(),
},
})
}

fn run_inference(
model: &dyn llm::Model,
model_config: &ModelConfig,
session: &mut llm::InferenceSession,
input: &str,
maximum_token_count: usize,
) -> (String, Result<InferenceStats, llm::InferenceError>) {
let mut actual_output: String = String::new();
let res = session.infer::<Infallible>(
model,
&mut rand::rngs::mock::StepRng::new(0, 1),
&llm::InferenceRequest {
prompt: input.into(),
parameters: &llm::InferenceParameters {
n_threads: model_config.threads,
n_batch: 1,
sampler: Arc::new(DeterministicSampler),
},
play_back_previous_tokens: false,
maximum_token_count: Some(maximum_token_count),
},
&mut Default::default(),
|r| match r {
llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => {
actual_output += &t;
Ok(llm::InferenceFeedback::Continue)
}
_ => Ok(llm::InferenceFeedback::Continue),
},
);

(actual_output, res)
}

#[derive(Debug)]
struct DeterministicSampler;
impl llm::Sampler for DeterministicSampler {
fn sample(
&self,
previous_tokens: &[llm::TokenId],
logits: &[f32],
_rng: &mut dyn rand::RngCore,
) -> llm::TokenId {
// Take the most likely token from the logits, excluding any token that has
// already appeared in `previous_tokens`.
let mut logits = logits.to_vec();
for &token in previous_tokens {
logits[token as usize] = f32::NEG_INFINITY;
}

logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0 as llm::TokenId
}
}
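
A note on the design: pairing `StepRng` (a mock RNG) with this greedy, repetition-masking sampler makes the generated text fully deterministic, so it can be compared byte-for-byte against `expected_output`. Below is a minimal standalone sketch of the same argmax-with-masking idea on toy values; it uses plain `usize` indices instead of `llm::TokenId` and is illustrative only, not part of this commit:

// Greedy argmax over logits, skipping indices that have already been emitted.
fn argmax_excluding(logits: &[f32], previous: &[usize]) -> usize {
    let mut logits = logits.to_vec();
    for &t in previous {
        logits[t] = f32::NEG_INFINITY; // never pick a repeated token
    }
    logits
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(index, _)| index)
        .unwrap()
}

fn main() {
    let logits = [0.1, 0.9, 0.5];
    assert_eq!(argmax_excluding(&logits, &[]), 1); // highest logit wins
    assert_eq!(argmax_excluding(&logits, &[1]), 2); // token 1 is masked, next best wins
    println!("deterministic sampling sketch ok");
}
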
155 changes: 8 additions & 147 deletions binaries/llm-test/src/main.rs
@@ -1,3 +1,8 @@
//! Test runner for all LLMs.

mod common;
mod inference;

use anyhow::Context;
use clap::Parser;
use indicatif::{ProgressBar, ProgressStyle};
@@ -7,13 +12,11 @@ use serde::{Deserialize, Serialize};
use std::{
cmp::min,
collections::HashMap,
convert::Infallible,
env,
fs::{self, File},
io::Write,
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
time::Instant,
};

@@ -240,10 +243,10 @@ async fn test_model(
//

// Confirm that the model can be sent to a thread, then sent back
let model = tests::can_send(model)?;
let model = common::can_send(model)?;

// Confirm that the hyperparameters can be roundtripped
tests::can_roundtrip_hyperparameters(&model)?;
common::can_roundtrip_hyperparameters(&model)?;

//

@@ -259,7 +262,7 @@ async fn test_model(
input,
output,
maximum_token_count,
} => test_case_reports.push(tests::can_infer(
} => test_case_reports.push(inference::can_infer(
&model,
model_config,
input,
@@ -320,148 +323,6 @@ fn write_report(
Ok(())
}

mod tests {
use super::*;

pub(super) fn can_send<M: llm::KnownModel + 'static>(model: M) -> anyhow::Result<M> {
let model = std::thread::spawn(move || model)
.join()
.map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}"));

log::info!("`can_send` test passed!");

model
}

pub(super) fn can_roundtrip_hyperparameters<M: llm::KnownModel + 'static>(
model: &M,
) -> anyhow::Result<()> {
fn test_hyperparameters<M: llm::Hyperparameters>(
hyperparameters: &M,
) -> anyhow::Result<()> {
let mut data = vec![];
hyperparameters.write_ggml(&mut data)?;
let new_hyperparameters =
<M as llm::Hyperparameters>::read_ggml(&mut std::io::Cursor::new(data))?;

assert_eq!(hyperparameters, &new_hyperparameters);

log::info!("`can_roundtrip_hyperparameters` test passed!");

Ok(())
}

test_hyperparameters(model.hyperparameters())
}

pub(super) fn can_infer(
model: &dyn llm::Model,
model_config: &ModelConfig,
input: &str,
expected_output: Option<&str>,
maximum_token_count: usize,
) -> anyhow::Result<TestCaseReport> {
let mut session = model.start_session(Default::default());
let (actual_output, res) = run_inference(
model,
model_config,
&mut session,
input,
maximum_token_count,
);

// Process the results
Ok(TestCaseReport {
meta: match &res {
Ok(_) => match expected_output {
Some(expected_output) => {
if expected_output == actual_output {
log::info!("`can_infer` test passed!");
TestCaseReportMeta::Success
} else {
TestCaseReportMeta::Error {
error: "The output did not match the expected output.".to_string(),
}
}
}
None => {
log::info!("`can_infer` test passed (no expected output)!");
TestCaseReportMeta::Success
}
},
Err(err) => TestCaseReportMeta::Error {
error: err.to_string(),
},
},
report: TestCaseReportInner::Inference {
input: input.into(),
expect_output: expected_output.map(|s| s.to_string()),
actual_output,
inference_stats: res.ok(),
},
})
}
}

fn run_inference(
model: &dyn llm::Model,
model_config: &ModelConfig,
session: &mut llm::InferenceSession,
input: &str,
maximum_token_count: usize,
) -> (String, Result<InferenceStats, llm::InferenceError>) {
let mut actual_output: String = String::new();
let res = session.infer::<Infallible>(
model,
&mut rand::rngs::mock::StepRng::new(0, 1),
&llm::InferenceRequest {
prompt: input.into(),
parameters: &llm::InferenceParameters {
n_threads: model_config.threads,
n_batch: 1,
sampler: Arc::new(DeterministicSampler),
},
play_back_previous_tokens: false,
maximum_token_count: Some(maximum_token_count),
},
&mut Default::default(),
|r| match r {
llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => {
actual_output += &t;
Ok(llm::InferenceFeedback::Continue)
}
_ => Ok(llm::InferenceFeedback::Continue),
},
);

(actual_output, res)
}

#[derive(Debug)]
struct DeterministicSampler;
impl llm::Sampler for DeterministicSampler {
fn sample(
&self,
previous_tokens: &[llm::TokenId],
logits: &[f32],
_rng: &mut dyn rand::RngCore,
) -> llm::TokenId {
// Take the most likely token from the logits, excluding any token that has
// already appeared in `previous_tokens`.
let mut logits = logits.to_vec();
for &token in previous_tokens {
logits[token as usize] = f32::NEG_INFINITY;
}

logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0 as llm::TokenId
}
}

async fn download_file(url: &str, local_path: &Path) -> anyhow::Result<()> {
if local_path.exists() {
return Ok(());
