Skip to content

Commit

Permalink
Add framework for CacheEngine
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Dec 4, 2023
1 parent fae2610 commit a186d85
Show file tree
Hide file tree
Showing 9 changed files with 473 additions and 268 deletions.
54 changes: 9 additions & 45 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,20 @@ use openai::pipelines::{
pub enum ModelSelected {
/// Select the llama7b model.
Llama7b {
#[arg(long)]
no_kv_cache: bool,
#[arg(long)]
repeat_last_n: usize,
#[arg(long)]
use_flash_attn: bool,
},

/// Select the llama13b model.
Llama13b {
#[arg(long)]
no_kv_cache: bool,
#[arg(long)]
repeat_last_n: usize,
#[arg(long)]
use_flash_attn: bool,
},

/// Select the llama70b model.
Llama70b {
#[arg(long)]
no_kv_cache: bool,
#[arg(long)]
repeat_last_n: usize,
#[arg(long)]
use_flash_attn: bool,
},

/// Select the mistral7b model.
Expand All @@ -53,21 +41,9 @@ pub enum ModelSelected {
impl ToString for ModelSelected {
fn to_string(&self) -> String {
match self {
ModelSelected::Llama7b {
no_kv_cache: _,
repeat_last_n: _,
use_flash_attn: _,
} => "llama7b".to_string(),
ModelSelected::Llama13b {
no_kv_cache: _,
repeat_last_n: _,
use_flash_attn: _,
} => "llama13b".to_string(),
ModelSelected::Llama70b {
no_kv_cache: _,
repeat_last_n: _,
use_flash_attn: _,
} => "llama70b".to_string(),
ModelSelected::Llama7b { repeat_last_n: _ } => "llama7b".to_string(),
ModelSelected::Llama13b { repeat_last_n: _ } => "llama13b".to_string(),
ModelSelected::Llama70b { repeat_last_n: _ } => "llama70b".to_string(),
ModelSelected::Mistral7b {
repeat_penalty: _,
repeat_last_n: _,
Expand All @@ -79,35 +55,23 @@ impl ToString for ModelSelected {

pub fn get_model_loader<'a>(selected_model: ModelSelected) -> (Box<dyn ModelLoader<'a>>, String) {
match selected_model {
ModelSelected::Llama7b {
no_kv_cache,
repeat_last_n,
use_flash_attn,
} => (
ModelSelected::Llama7b { repeat_last_n } => (
Box::new(LlamaLoader::new(
LlamaSpecificConfig::new(no_kv_cache, repeat_last_n, use_flash_attn),
LlamaSpecificConfig::new(repeat_last_n),
"llama7b".to_string(),
)),
"meta-llama/Llama-2-7b-chat-hf".to_string(),
),
ModelSelected::Llama13b {
no_kv_cache,
repeat_last_n,
use_flash_attn,
} => (
ModelSelected::Llama13b { repeat_last_n } => (
Box::new(LlamaLoader::new(
LlamaSpecificConfig::new(no_kv_cache, repeat_last_n, use_flash_attn),
LlamaSpecificConfig::new(repeat_last_n),
"llama13b".to_string(),
)),
"meta-llama/Llama-2-13b-chat-hf".to_string(),
),
ModelSelected::Llama70b {
no_kv_cache,
repeat_last_n,
use_flash_attn,
} => (
ModelSelected::Llama70b { repeat_last_n } => (
Box::new(LlamaLoader::new(
LlamaSpecificConfig::new(no_kv_cache, repeat_last_n, use_flash_attn),
LlamaSpecificConfig::new(repeat_last_n),
"llama70b".to_string(),
)),
"meta-llama/Llama-2-70b-chat-hf".to_string(),
Expand Down
Loading

0 comments on commit a186d85

Please sign in to comment.