Skip to content

Commit

Permalink
Allow bootable server
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Nov 24, 2023
1 parent d161aee commit 4d8a380
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
17 changes: 17 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@ pub enum ModelSelected {
},
}

impl ToString for ModelSelected {
fn to_string(&self) -> String {
match self {
ModelSelected::Llama {
no_kv_cache: _,
repeat_last_n: _,
use_flash_attn: _,
} => "llama".to_string(),
ModelSelected::Mistral {
repeat_penalty: _,
repeat_last_n: _,
use_flash_attn: _,
} => "mistral".to_string(),
}
}
}

pub fn get_model_loader<'a>(selected_model: ModelSelected) -> (Box<dyn ModelLoader<'a>>, String) {
match selected_model {
ModelSelected::Llama {
Expand Down
16 changes: 9 additions & 7 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::collections::HashMap;
use std::sync::Mutex;

use actix_web::web::Data;
use actix_web::{http::header::ContentType, test, App};
use actix_web::{test, App};
use candle_core::{DType, Device};
use candle_vllm::openai::openai_server::chat_completions;
use candle_vllm::openai::requests::Messages;
use candle_vllm::openai::responses::APIError;
use candle_vllm::openai::{self, OpenAIServerData};
use candle_vllm::openai::OpenAIServerData;
use candle_vllm::{get_model_loader, ModelSelected};
use clap::Parser;

Expand All @@ -26,6 +24,8 @@ struct Args {
async fn main() -> Result<(), APIError> {
let args = Args::parse();

println!("Loading {} model...", args.command.to_string());

let (loader, model_id) = get_model_loader(args.command);
let paths = loader.download_model(model_id, None, args.hf_token)?;
let model = loader.load_model(paths, DType::F16, Device::Cpu)?;
Expand All @@ -36,14 +36,16 @@ async fn main() -> Result<(), APIError> {
device: Device::Cpu,
};

let app = test::init_service(
println!("Starting server...");

let _app = test::init_service(
App::new()
.service(chat_completions)
.app_data(Data::new(server_data)),
)
.await;

let mut system = HashMap::new();
/*let mut system = HashMap::new();
system.insert("role".to_string(), "system".to_string());
system.insert(
"content".to_string(),
Expand Down Expand Up @@ -83,6 +85,6 @@ async fn main() -> Result<(), APIError> {
let resp = test::call_service(&app, req).await;
println!("{:?}", resp.status());
println!("{:?}", resp.into_body());
println!("{:?}", resp.into_body());*/
Ok(())
}

0 comments on commit 4d8a380

Please sign in to comment.