feat: multiple models
sangjanai committed May 7, 2024
1 parent 3e7fd66 commit 2f281bb
Showing 8 changed files with 166 additions and 97 deletions.
5 changes: 3 additions & 2 deletions .github/scripts/e2e-test-server-linux-and-mac.sh
@@ -51,6 +51,7 @@ response1=$(curl --connect-timeout 60 -o /tmp/load-llm-model-res.log -s -w "%{ht
--header 'Content-Type: application/json' \
--data '{
"llama_model_path": "/tmp/testllm",
"model_alias": "testllm",
"ctx_len": 50,
"ngl": 32,
"embedding": false
@@ -73,7 +74,7 @@ response2=$(
{"content": "Write a long and sad story for me", "role": "user"}
],
"stream": true,
"model": "gpt-3.5-turbo",
"model": "testllm",
"max_tokens": 50,
"stop": ["hello"],
"frequency_penalty": 0,
@@ -83,7 +84,7 @@
)

# unload model
response3=$(curl --connect-timeout 60 -o /tmp/unload-model-res.log --request GET -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/unloadmodel" \
response3=$(curl --connect-timeout 60 -o /tmp/unload-model-res.log -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/unloadmodel" \
--header 'Content-Type: application/json' \
--data '{
"llama_model_path": "/tmp/testllm"
4 changes: 2 additions & 2 deletions .github/scripts/e2e-test-server-windows.bat
@@ -63,7 +63,7 @@ rem Define JSON strings for curl data
call set "MODEL_LLM_PATH_STRING=%%MODEL_LLM_PATH:\=\\%%"
call set "MODEL_EMBEDDING_PATH_STRING=%%MODEL_EMBEDDING_PATH:\=\\%%"
set "curl_data1={\"llama_model_path\":\"%MODEL_LLM_PATH_STRING%\"}"
set "curl_data2={\"messages\":[{\"content\":\"Hello there\",\"role\":\"assistant\"},{\"content\":\"Write a long and sad story for me\",\"role\":\"user\"}],\"stream\":false,\"model\":\"gpt-3.5-turbo\",\"max_tokens\":50,\"stop\":[\"hello\"],\"frequency_penalty\":0,\"presence_penalty\":0,\"temperature\":0.1}"
set "curl_data2={\"messages\":[{\"content\":\"Hello there\",\"role\":\"assistant\"},{\"content\":\"Write a long and sad story for me\",\"role\":\"user\"}],\"stream\":false,\"model\":\"testllm\",\"max_tokens\":50,\"stop\":[\"hello\"],\"frequency_penalty\":0,\"presence_penalty\":0,\"temperature\":0.1}"
set "curl_data3={\"llama_model_path\":\"%MODEL_LLM_PATH_STRING%\"}"
set "curl_data4={\"llama_model_path\":\"%MODEL_EMBEDDING_PATH_STRING%\", \"embedding\": true, \"model_type\": \"embedding\"}"
set "curl_data5={\"input\": \"Hello\", \"model\": \"test-embedding\", \"encoding_format\": \"float\"}"
@@ -82,7 +82,7 @@ curl.exe --connect-timeout 60 -o "%TEMP%\response2.log" -s -w "%%{http_code}" --
--header "Content-Type: application/json" ^
--data "%curl_data2%" > %TEMP%\response2.log 2>&1

curl.exe --connect-timeout 60 -o "%TEMP%\response3.log" --request GET -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/unloadmodel" --header "Content-Type: application/json" --data "%curl_data3%" > %TEMP%\response3.log 2>&1
curl.exe --connect-timeout 60 -o "%TEMP%\response3.log" -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/unloadmodel" --header "Content-Type: application/json" --data "%curl_data3%" > %TEMP%\response3.log 2>&1

curl.exe --connect-timeout 60 -o "%TEMP%\response4.log" --request POST -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/loadmodel" --header "Content-Type: application/json" --data "%curl_data4%" > %TEMP%\response4.log 2>&1

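
Both e2e scripts now target a specific model instead of the single implicit one: the Linux/macOS script registers the alias testllm at load time and reuses that id as the model field of the chat request (replacing the gpt-3.5-turbo placeholder), while the Windows script loads by path only and still chats as testllm, which suggests the id can fall back to the model file name. The sketch below shows one way such a resolver could work; llama_utils::getModelId itself is not part of this diff, so the field precedence shown here is an assumption.

// Hypothetical model-id resolver, similar in spirit to llama_utils::getModelId.
// Assumed precedence: "model_alias", then "model", then the basename of
// "llama_model_path" (e.g. "/tmp/testllm" -> "testllm").
#include <json/json.h>
#include <string>

std::string GetModelIdSketch(const Json::Value& body) {
  if (body.isMember("model_alias")) return body["model_alias"].asString();
  if (body.isMember("model")) return body["model"].asString();
  std::string path = body.get("llama_model_path", "").asString();
  auto pos = path.find_last_of("/\\");
  return pos == std::string::npos ? path : path.substr(pos + 1);
}

Under that assumption, the Windows load request, which carries only llama_model_path, still resolves to testllm.
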
5 changes: 3 additions & 2 deletions examples/server/server.cc
@@ -200,10 +200,11 @@ int main(int argc, char** argv) {
};

svr->Post("/loadmodel", handle_load_model);
svr->Get("/unloadmodel", handle_unload_model);
// Use POST since httplib does not read request body for GET method
svr->Post("/unloadmodel", handle_unload_model);
svr->Post("/v1/chat/completions", handle_completions);
svr->Post("/v1/embeddings", handle_embeddings);
svr->Get("/modelstatus", handle_get_model_status);
svr->Post("/modelstatus", handle_get_model_status);

LOG_INFO << "HTTP server listening: " << hostname << ":" << port;
svr->new_task_queue = [] {
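
The /unloadmodel and /modelstatus routes switch from GET to POST because, as the new comment notes, httplib does not read a request body for GET, and both endpoints now need the JSON body to know which model they should act on. A self-contained illustration of that pattern follows; it is a sketch of general cpp-httplib/JsonCpp usage, not the project's actual handler wiring, and the port and response shape are made up for the example.

#include <httplib.h>
#include <json/json.h>
#include <string>

int main() {
  httplib::Server svr;
  // POST lets the handler read req.body, e.g. {"model": "testllm"}.
  svr.Post("/unloadmodel", [](const httplib::Request& req, httplib::Response& res) {
    Json::Value body;
    Json::Reader reader;
    reader.parse(req.body, body);
    std::string model_id = body.get("model", "").asString();
    res.set_content("{\"unloaded\":\"" + model_id + "\"}", "application/json");
  });
  svr.listen("127.0.0.1", 3928);  // port chosen for illustration only
}
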
175 changes: 92 additions & 83 deletions src/LlamaEngine.cc
@@ -113,14 +113,13 @@ std::string create_return_json(const std::string& id, const std::string& model,
} // namespace

LlamaEngine::~LlamaEngine() {
StopBackgroundTask();
}

void LlamaEngine::HandleChatCompletion(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
// Check if model is loaded
if (CheckModelLoaded(callback)) {
if (CheckModelLoaded(callback, llama_utils::getModelId(*jsonBody))) {
// Model is loaded
// Do Inference
HandleInferenceImpl(llama::inferences::fromJson(jsonBody),
@@ -132,7 +131,7 @@ void LlamaEngine::HandleEmbedding(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
// Check if model is loaded
if (CheckModelLoaded(callback)) {
if (CheckModelLoaded(callback, llama_utils::getModelId(*jsonBody))) {
// Run embedding
HandleEmbeddingImpl(jsonBody, std::move(callback));
}
@@ -156,7 +155,22 @@ void LlamaEngine::LoadModel(
return;
}

if (llama_.model_loaded_external) {
auto model_id = llama_utils::getModelId(*jsonBody);
if (model_id.empty()) {
LOG_INFO << "Model id is empty in request";
Json::Value jsonResp;
jsonResp["message"] = "No model id found in request body";
Json::Value status;
status["is_done"] = false;
status["has_error"] = true;
status["is_stream"] = false;
status["status_code"] = k400BadRequest;
callback(std::move(status), std::move(jsonResp));
return;
}

if (auto l = server_ctx_map_.find(model_id);
l != server_ctx_map_.end() && l->second.model_loaded_external) {
LOG_INFO << "Model already loaded";
Json::Value jsonResp;
jsonResp["message"] = "Model already loaded";
@@ -189,22 +203,18 @@ void LlamaEngine::LoadModel(
status["is_stream"] = false;
status["status_code"] = k200OK;
callback(std::move(status), std::move(jsonResp));
LOG_INFO << "Model loaded successfully";
LOG_INFO << "Model loaded successfully: " << model_id;
}
}

void LlamaEngine::UnloadModel(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
auto model_id = llama_utils::getModelId(*jsonBody);
if (CheckModelLoaded(callback, model_id)) {
auto& l = server_ctx_map_[model_id];
l.release_resources();

if (CheckModelLoaded(callback)) {
StopBackgroundTask();

llama_free(llama_.ctx);
llama_free_model(llama_.model);
llama_.ctx = nullptr;
llama_.model = nullptr;
llama_backend_free();
Json::Value jsonResp;
jsonResp["message"] = "Model unloaded successfully";
Json::Value status;
@@ -214,6 +224,8 @@ void LlamaEngine::UnloadModel(
status["status_code"] = k200OK;
callback(std::move(status), std::move(jsonResp));

server_ctx_map_.erase(model_id);
ifr_task_queue_map_.erase(model_id);
LOG_INFO << "Model unloaded successfully";
}
}
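
Unloading is now scoped to one entry of the registry: release the model's resources, report success, then erase both the context and its task queue, leaving other loaded models untouched. A compressed sketch of that flow, with stand-in types for illustration:

#include <functional>
#include <string>
#include <unordered_map>

struct CtxStub { void release_resources() {} };  // stand-in for the llama server context

std::unordered_map<std::string, CtxStub> ctx_map;
std::unordered_map<std::string, int> queue_map;  // stand-in for the per-model task queues

void UnloadSketch(const std::string& model_id, const std::function<void()>& reply) {
  ctx_map[model_id].release_resources();  // free llama resources for this model only
  reply();                                // respond to the caller
  ctx_map.erase(model_id);                // then drop the per-model bookkeeping
  queue_map.erase(model_id);
}
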
@@ -222,11 +234,13 @@ void LlamaEngine::GetModelStatus(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {

bool is_model_loaded = llama_.model_loaded_external;
if (CheckModelLoaded(callback)) {
auto model_id = llama_utils::getModelId(*jsonBody);
if (auto is_loaded = CheckModelLoaded(callback, model_id); is_loaded) {
// CheckModelLoaded guarantees that model_id exists in server_ctx_map_.
auto l = server_ctx_map_.find(model_id);
Json::Value jsonResp;
jsonResp["model_loaded"] = is_model_loaded;
jsonResp["model_data"] = llama_.get_model_props().dump();
jsonResp["model_loaded"] = is_loaded;
jsonResp["model_data"] = l->second.get_model_props().dump();
Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
@@ -285,11 +299,6 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
params.embedding = jsonBody->get("embedding", true).asBool();
model_type = jsonBody->get("model_type", "llm").asString();
if (model_type == "llm") {
llama_.model_type = ModelType::LLM;
} else {
llama_.model_type = ModelType::EMBEDDING;
}
// Check if n_parallel exists in jsonBody, if not, set to drogon_thread
params.n_batch = jsonBody->get("n_batch", 512).asInt();
params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
@@ -318,6 +327,7 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
params.model_alias = params.model;
}

if (ShouldInitBackend()) {
llama_backend_init();

// LOG_INFO_LLAMA("build info",
Expand All @@ -328,34 +338,37 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
{"total_threads", std::thread::hardware_concurrency()},
{"system_info", llama_print_system_info()},
});

}
auto model_id = llama_utils::getModelId(*jsonBody);
// load the model
if (!llama_.load_model(params)) {
if (!server_ctx_map_[model_id].load_model(params)) {
LOG_ERROR << "Error loading the model";
return false; // Indicate failure
}
llama_.initialize();

queue_ = std::make_unique<trantor::ConcurrentTaskQueue>(params.n_parallel,
"llamaCPP");

llama_.model_loaded_external = true;
if (model_type == "llm") {
server_ctx_map_[model_id].model_type = ModelType::LLM;
} else {
server_ctx_map_[model_id].model_type = ModelType::EMBEDDING;
}
server_ctx_map_[model_id].initialize();

LOG_INFO << "Started background task here!";
bgr_thread_ = std::thread(&LlamaEngine::HandleBackgroundTask, this);
ifr_task_queue_map_.emplace(model_id, std::make_unique<trantor::ConcurrentTaskQueue>(
params.n_parallel, model_id));

// Models like nomic-embed-text-v1.5.f16.gguf don't need a warm-up pass,
// so we use the model type to skip warm-up for embedding models.
if (llama_.model_type == ModelType::LLM) {
WarmUpModel();
if (server_ctx_map_[model_id].model_type == ModelType::LLM) {
WarmUpModel(model_id);
}
return true;
}

void LlamaEngine::HandleInferenceImpl(
llama::inferences::ChatCompletionRequest&& completion,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
if (llama_.model_type == ModelType::EMBEDDING) {
assert(server_ctx_map_.find(completion.model_id) != server_ctx_map_.end());
auto& l = server_ctx_map_[completion.model_id];
if (l.model_type == ModelType::EMBEDDING) {
LOG_WARN << "Not support completion for embedding model";
Json::Value jsonResp;
jsonResp["message"] = "Not support completion for embedding model";
@@ -398,7 +411,7 @@ void LlamaEngine::HandleInferenceImpl(
data["grammar"] = grammar_file_content_;
};

if (!llama_.multimodal) {
if (!l.multimodal) {
for (const auto& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
@@ -508,10 +521,10 @@ if (is_streamed) {
if (is_streamed) {
LOG_INFO << "Request " << request_id << ": "
<< "Streamed, waiting for respone";
auto state = create_inference_state(llama_);
auto state = create_inference_state(l);

// Queued task
queue_->runTaskInQueue([cb = std::move(callback), state, data,
ifr_task_queue_map_[completion.model_id]->runTaskInQueue([cb = std::move(callback), state, data,
request_id]() {
state->task_id = state->llama.request_completion(data, false, false, -1);
while (state->llama.model_loaded_external) {
@@ -589,16 +602,17 @@ void LlamaEngine::HandleInferenceImpl(
<< "Inference completed";
});
} else {
queue_->runTaskInQueue(
[this, request_id, cb = std::move(callback), d = std::move(data)]() {
auto state = create_inference_state(l);
ifr_task_queue_map_[completion.model_id]->runTaskInQueue(
[this, request_id, state, cb = std::move(callback), d = std::move(data)]() {
Json::Value respData;
int task_id = llama_.request_completion(d, false, false, -1);
int task_id = state->llama.request_completion(d, false, false, -1);
LOG_INFO << "Request " << request_id << ": "
<< "Non stream, waiting for respone";
if (!json_value(d, "stream", false)) {
bool has_error = false;
std::string completion_text;
task_result result = llama_.next_result(task_id);
task_result result = state->llama.next_result(task_id);
if (!result.error && result.stop) {
int prompt_tokens = result.result_json["tokens_evaluated"];
int predicted_tokens = result.result_json["tokens_predicted"];
@@ -630,32 +644,34 @@ void LlamaEngine::HandleEmbeddingImpl(
void LlamaEngine::HandleEmbeddingImpl(
std::shared_ptr<Json::Value> jsonBody,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
auto model_id = llama_utils::getModelId(*jsonBody);
assert(server_ctx_map_.find(model_id) != server_ctx_map_.end());
int request_id = ++no_of_requests;
LOG_INFO << "Request " << request_id << ": "
<< "Generating reponse for embedding request";
// Queue embedding task
auto state = create_inference_state(llama_);
auto state = create_inference_state(server_ctx_map_[model_id]);

queue_->runTaskInQueue([this, state, jsonBody, callback, request_id]() {
ifr_task_queue_map_[model_id]->runTaskInQueue([this, state, jsonBody, callback, request_id]() {
Json::Value responseData(Json::arrayValue);

if (jsonBody->isMember("input")) {
const Json::Value& input = (*jsonBody)["input"];
if (input.isString()) {
// Process the single string input
state->task_id = llama_.request_completion(
state->task_id = state->llama.request_completion(
{{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
task_result result = llama_.next_result(state->task_id);
task_result result = state->llama.next_result(state->task_id);
std::vector<float> embedding_result = result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));
} else if (input.isArray()) {
// Process each element in the array input
for (const auto& elem : input) {
if (elem.isString()) {
const int task_id = llama_.request_completion(
const int task_id = state->llama.request_completion(
{{"prompt", elem.asString()}, {"n_predict", 0}}, false, true,
-1);
task_result result = llama_.next_result(task_id);
task_result result = state->llama.next_result(task_id);
std::vector<float> embedding_result =
result.result_json["embedding"];
responseData.append(create_embedding_payload(embedding_result, 0));
@@ -685,9 +701,12 @@ void LlamaEngine::HandleEmbeddingImpl(
}

bool LlamaEngine::CheckModelLoaded(
std::function<void(Json::Value&&, Json::Value&&)>& callback) {
if (!llama_.model_loaded_external) {
LOG_ERROR << "Model has not been loaded";
std::function<void(Json::Value&&, Json::Value&&)>& callback, const std::string& model_id) {
if (auto l = server_ctx_map_.find(model_id);
l == server_ctx_map_.end() || !l->second.model_loaded_external) {
LOG_ERROR << "Error: model_id: " << model_id
<< ", existed: " << (l == server_ctx_map_.end())
<< ", loaded: " << (l != server_ctx_map_.end());
Json::Value jsonResp;
jsonResp["message"] =
"Model has not been loaded, please load model into nitro";
Expand All @@ -702,42 +721,32 @@ bool LlamaEngine::CheckModelLoaded(
return true;
}

void LlamaEngine::WarmUpModel() {
json pseudo;

LOG_INFO << "Warm-up model";
pseudo["prompt"] = "Hello";
pseudo["n_predict"] = 2;
pseudo["stream"] = false;
const int task_id = llama_.request_completion(pseudo, false, false, -1);
std::string completion_text;
task_result result = llama_.next_result(task_id);
if (!result.error && result.stop) {
LOG_INFO << result.result_json.dump(-1, ' ', false,
json::error_handler_t::replace);
}
void LlamaEngine::WarmUpModel(const std::string& model_id) {
if (auto l = server_ctx_map_.find(model_id); l != server_ctx_map_.end()) {
json pseudo;

LOG_INFO << "Warm-up model: " << model_id;
pseudo["prompt"] = "Hello";
pseudo["n_predict"] = 2;
pseudo["stream"] = false;
const int task_id = l->second.request_completion(pseudo, false, false, -1);
task_result result = l->second.next_result(task_id);
if (!result.error && result.stop) {
LOG_INFO << result.result_json.dump(-1, ' ', false,
json::error_handler_t::replace);
}
} else {
LOG_WARN << "Model not found " << model_id;
}

void LlamaEngine::HandleBackgroundTask() {
while (llama_.model_loaded_external) {
// model_loaded =
llama_.update_slots();
}
LOG_INFO << "Background task stopped! ";
llama_.kv_cache_clear();
LOG_INFO << "KV cache cleared!";
}

void LlamaEngine::StopBackgroundTask() {
if (llama_.model_loaded_external) {
llama_.model_loaded_external = false;
llama_.condition_tasks.notify_one();
LOG_INFO << "Stopping background task! ";
if (bgr_thread_.joinable()) {
bgr_thread_.join();
}
LOG_INFO << "Background task stopped! ";
bool LlamaEngine::ShouldInitBackend() const {
// May have race condition here, need to check
for (auto& [_, l] : server_ctx_map_) {
if (l.model_loaded_external)
return false;
}
return true;
}

extern "C" {
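
With per-model contexts, the old HandleBackgroundTask/StopBackgroundTask pair built around the single global llama_ is gone, CheckModelLoaded now answers for the requested id only, and ShouldInitBackend initializes the llama backend only when no model is currently loaded, so loading a second model doesn't re-run llama_backend_init (the author's own comment flags a possible race when models are loaded concurrently). A condensed sketch of the ShouldInitBackend logic, with a stand-in context type:

#include <string>
#include <unordered_map>

struct CtxStub { bool model_loaded_external = false; };  // stand-in context

bool ShouldInitBackendSketch(
    const std::unordered_map<std::string, CtxStub>& contexts) {
  // Initialize the backend only if no loaded model currently holds it.
  for (const auto& [id, ctx] : contexts) {
    if (ctx.model_loaded_external) return false;
  }
  return true;
}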