From 2646aba7f8c151fbdc3eb7b6d72f326a7727c894 Mon Sep 17 00:00:00 2001
From: Steven Weiss
Date: Wed, 5 Jul 2023 12:47:18 -0700
Subject: [PATCH] Refactor tests to put each type in its own file.

This'll help when I add my tests, because otherwise this file will explode in size
---
 binaries/llm-test/src/common.rs    |  30 ++++++
 binaries/llm-test/src/inference.rs | 114 +++++++++++++++++++++
 binaries/llm-test/src/main.rs      | 155 ++---------------------------
 3 files changed, 152 insertions(+), 147 deletions(-)
 create mode 100644 binaries/llm-test/src/common.rs
 create mode 100644 binaries/llm-test/src/inference.rs

diff --git a/binaries/llm-test/src/common.rs b/binaries/llm-test/src/common.rs
new file mode 100644
index 00000000..4c858820
--- /dev/null
+++ b/binaries/llm-test/src/common.rs
@@ -0,0 +1,30 @@
+//! Tests that are run on every model, regardless of config.
+
+pub(super) fn can_send<M: llm::KnownModel + 'static>(model: M) -> anyhow::Result<M> {
+    let model = std::thread::spawn(move || model)
+        .join()
+        .map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}"));
+
+    log::info!("`can_send` test passed!");
+
+    model
+}
+
+pub(super) fn can_roundtrip_hyperparameters<M: llm::KnownModel>(
+    model: &M,
+) -> anyhow::Result<()> {
+    fn test_hyperparameters<M: llm::Hyperparameters>(hyperparameters: &M) -> anyhow::Result<()> {
+        let mut data = vec![];
+        hyperparameters.write_ggml(&mut data)?;
+        let new_hyperparameters =
+            <M as llm::Hyperparameters>::read_ggml(&mut std::io::Cursor::new(data))?;
+
+        assert_eq!(hyperparameters, &new_hyperparameters);
+
+        log::info!("`can_roundtrip_hyperparameters` test passed!");
+
+        Ok(())
+    }
+
+    test_hyperparameters(model.hyperparameters())
+}
diff --git a/binaries/llm-test/src/inference.rs b/binaries/llm-test/src/inference.rs
new file mode 100644
index 00000000..42d8f5be
--- /dev/null
+++ b/binaries/llm-test/src/inference.rs
@@ -0,0 +1,114 @@
+//! Test cases for [crate::TestCase::Inference] tests.
+
+use std::{convert::Infallible, sync::Arc};
+
+use llm::InferenceStats;
+
+use crate::{ModelConfig, TestCaseReport, TestCaseReportInner, TestCaseReportMeta};
+
+pub(super) fn can_infer(
+    model: &dyn llm::Model,
+    model_config: &ModelConfig,
+    input: &str,
+    expected_output: Option<&str>,
+    maximum_token_count: usize,
+) -> anyhow::Result<TestCaseReport> {
+    let mut session = model.start_session(Default::default());
+    let (actual_output, res) = run_inference(
+        model,
+        model_config,
+        &mut session,
+        input,
+        maximum_token_count,
+    );
+
+    // Process the results
+    Ok(TestCaseReport {
+        meta: match &res {
+            Ok(_) => match expected_output {
+                Some(expected_output) => {
+                    if expected_output == actual_output {
+                        log::info!("`can_infer` test passed!");
+                        TestCaseReportMeta::Success
+                    } else {
+                        TestCaseReportMeta::Error {
+                            error: "The output did not match the expected output.".to_string(),
+                        }
+                    }
+                }
+                None => {
+                    log::info!("`can_infer` test passed (no expected output)!");
+                    TestCaseReportMeta::Success
+                }
+            },
+            Err(err) => TestCaseReportMeta::Error {
+                error: err.to_string(),
+            },
+        },
+        report: TestCaseReportInner::Inference {
+            input: input.into(),
+            expect_output: expected_output.map(|s| s.to_string()),
+            actual_output,
+            inference_stats: res.ok(),
+        },
+    })
+}
+
+fn run_inference(
+    model: &dyn llm::Model,
+    model_config: &ModelConfig,
+    session: &mut llm::InferenceSession,
+    input: &str,
+    maximum_token_count: usize,
+) -> (String, Result<InferenceStats, llm::InferenceError>) {
+    let mut actual_output: String = String::new();
+    let res = session.infer::<Infallible>(
+        model,
+        &mut rand::rngs::mock::StepRng::new(0, 1),
+        &llm::InferenceRequest {
+            prompt: input.into(),
+            parameters: &llm::InferenceParameters {
+                n_threads: model_config.threads,
+                n_batch: 1,
+                sampler: Arc::new(DeterministicSampler),
+            },
+            play_back_previous_tokens: false,
+            maximum_token_count: Some(maximum_token_count),
+        },
+        &mut Default::default(),
+        |r| match r {
+            llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => {
+                actual_output += &t;
+                Ok(llm::InferenceFeedback::Continue)
+            }
+            _ => Ok(llm::InferenceFeedback::Continue),
+        },
+    );
+
+    (actual_output, res)
+}
+
+#[derive(Debug)]
+struct DeterministicSampler;
+impl llm::Sampler for DeterministicSampler {
+    fn sample(
+        &self,
+        previous_tokens: &[llm::TokenId],
+        logits: &[f32],
+        _rng: &mut dyn rand::RngCore,
+    ) -> llm::TokenId {
+        // Takes the most likely element from the logits, except if they've appeared in `previous_tokens`
+        // at all
+        let mut logits = logits.to_vec();
+        for &token in previous_tokens {
+            logits[token as usize] = f32::NEG_INFINITY;
+        }
+
+        logits
+            .iter()
+            .enumerate()
+            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
+            .unwrap()
+            .0 as llm::TokenId
+    }
+}
diff --git a/binaries/llm-test/src/main.rs b/binaries/llm-test/src/main.rs
index aa259779..60df0153 100644
--- a/binaries/llm-test/src/main.rs
+++ b/binaries/llm-test/src/main.rs
@@ -1,3 +1,8 @@
+//! Test runner for all LLMs.
+
+mod common;
+mod inference;
+
 use anyhow::Context;
 use clap::Parser;
 use indicatif::{ProgressBar, ProgressStyle};
@@ -7,13 +12,11 @@ use serde::{Deserialize, Serialize};
 use std::{
     cmp::min,
     collections::HashMap,
-    convert::Infallible,
     env,
     fs::{self, File},
     io::Write,
     path::{Path, PathBuf},
     str::FromStr,
-    sync::Arc,
     time::Instant,
 };
 
@@ -240,10 +243,10 @@ async fn test_model(
 
     //
 
     // Confirm that the model can be sent to a thread, then sent back
-    let model = tests::can_send(model)?;
+    let model = common::can_send(model)?;
 
     // Confirm that the hyperparameters can be roundtripped
-    tests::can_roundtrip_hyperparameters(&model)?;
+    common::can_roundtrip_hyperparameters(&model)?;
 
     //
@@ -259,7 +262,7 @@ async fn test_model(
                 input,
                 output,
                 maximum_token_count,
-            } => test_case_reports.push(tests::can_infer(
+            } => test_case_reports.push(inference::can_infer(
                 &model,
                 model_config,
                 input,
@@ -320,148 +323,6 @@ fn write_report(
     Ok(())
 }
 
-mod tests {
-    use super::*;
-
-    pub(super) fn can_send<M: llm::KnownModel + 'static>(model: M) -> anyhow::Result<M> {
-        let model = std::thread::spawn(move || model)
-            .join()
-            .map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}"));
-
-        log::info!("`can_send` test passed!");
-
-        model
-    }
-
-    pub(super) fn can_roundtrip_hyperparameters<M: llm::KnownModel>(
-        model: &M,
-    ) -> anyhow::Result<()> {
-        fn test_hyperparameters<M: llm::Hyperparameters>(
-            hyperparameters: &M,
-        ) -> anyhow::Result<()> {
-            let mut data = vec![];
-            hyperparameters.write_ggml(&mut data)?;
-            let new_hyperparameters =
-                <M as llm::Hyperparameters>::read_ggml(&mut std::io::Cursor::new(data))?;
-
-            assert_eq!(hyperparameters, &new_hyperparameters);
-
-            log::info!("`can_roundtrip_hyperparameters` test passed!");
-
-            Ok(())
-        }
-
-        test_hyperparameters(model.hyperparameters())
-    }
-
-    pub(super) fn can_infer(
-        model: &dyn llm::Model,
-        model_config: &ModelConfig,
-        input: &str,
-        expected_output: Option<&str>,
-        maximum_token_count: usize,
-    ) -> anyhow::Result<TestCaseReport> {
-        let mut session = model.start_session(Default::default());
-        let (actual_output, res) = run_inference(
-            model,
-            model_config,
-            &mut session,
-            input,
-            maximum_token_count,
-        );
-
-        // Process the results
-        Ok(TestCaseReport {
-            meta: match &res {
-                Ok(_) => match expected_output {
-                    Some(expected_output) => {
-                        if expected_output == actual_output {
-                            log::info!("`can_infer` test passed!");
-                            TestCaseReportMeta::Success
-                        } else {
-                            TestCaseReportMeta::Error {
-                                error: "The output did not match the expected output.".to_string(),
-                            }
-                        }
-                    }
-                    None => {
-                        log::info!("`can_infer` test passed (no expected output)!");
-                        TestCaseReportMeta::Success
-                    }
-                },
-                Err(err) => TestCaseReportMeta::Error {
-                    error: err.to_string(),
-                },
-            },
-            report: TestCaseReportInner::Inference {
-                input: input.into(),
-                expect_output: expected_output.map(|s| s.to_string()),
-                actual_output,
-                inference_stats: res.ok(),
-            },
-        })
-    }
-}
-
-fn run_inference(
-    model: &dyn llm::Model,
-    model_config: &ModelConfig,
-    session: &mut llm::InferenceSession,
-    input: &str,
-    maximum_token_count: usize,
-) -> (String, Result<InferenceStats, llm::InferenceError>) {
-    let mut actual_output: String = String::new();
-    let res = session.infer::<Infallible>(
-        model,
-        &mut rand::rngs::mock::StepRng::new(0, 1),
-        &llm::InferenceRequest {
-            prompt: input.into(),
-            parameters: &llm::InferenceParameters {
-                n_threads: model_config.threads,
-                n_batch: 1,
-                sampler: Arc::new(DeterministicSampler),
-            },
-            play_back_previous_tokens: false,
-            maximum_token_count: Some(maximum_token_count),
-        },
-        &mut Default::default(),
-        |r| match r {
-            llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => {
-                actual_output += &t;
-                Ok(llm::InferenceFeedback::Continue)
-            }
-            _ => Ok(llm::InferenceFeedback::Continue),
-        },
-    );
-
-    (actual_output, res)
-}
-
-#[derive(Debug)]
-struct DeterministicSampler;
-impl llm::Sampler for DeterministicSampler {
-    fn sample(
-        &self,
-        previous_tokens: &[llm::TokenId],
-        logits: &[f32],
-        _rng: &mut dyn rand::RngCore,
-    ) -> llm::TokenId {
-        // Takes the most likely element from the logits, except if they've appeared in `previous_tokens`
-        // at all
-        let mut logits = logits.to_vec();
-        for &token in previous_tokens {
-            logits[token as usize] = f32::NEG_INFINITY;
-        }
-
-        logits
-            .iter()
-            .enumerate()
-            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
-            .unwrap()
-            .0 as llm::TokenId
-    }
-}
-
 async fn download_file(url: &str, local_path: &Path) -> anyhow::Result<()> {
     if local_path.exists() {
         return Ok(());