diff --git a/.github/scripts/e2e-test-llama-linux-and-mac.sh b/.github/scripts/e2e-test-llama-linux-and-mac.sh index 5b7b9771d..ab9656619 100644 --- a/.github/scripts/e2e-test-llama-linux-and-mac.sh +++ b/.github/scripts/e2e-test-llama-linux-and-mac.sh @@ -51,6 +51,7 @@ response1=$(curl --connect-timeout 60 -o /tmp/load-llm-model-res.log -s -w "%{ht --header 'Content-Type: application/json' \ --data '{ "llama_model_path": "/tmp/testllm", + "model_alias": "testllm", "ctx_len": 50, "ngl": 32, "embedding": false @@ -73,7 +74,7 @@ response2=$( {"content": "Write a long and sad story for me", "role": "user"} ], "stream": true, - "model": "gpt-3.5-turbo", + "model": "testllm", "max_tokens": 50, "stop": ["hello"], "frequency_penalty": 0, diff --git a/.github/scripts/e2e-test-llama-windows.bat b/.github/scripts/e2e-test-llama-windows.bat index cddca1e0b..fd3a21f09 100644 --- a/.github/scripts/e2e-test-llama-windows.bat +++ b/.github/scripts/e2e-test-llama-windows.bat @@ -63,7 +63,7 @@ rem Define JSON strings for curl data call set "MODEL_LLM_PATH_STRING=%%MODEL_LLM_PATH:\=\\%%" call set "MODEL_EMBEDDING_PATH_STRING=%%MODEL_EMBEDDING_PATH:\=\\%%" set "curl_data1={\"llama_model_path\":\"%MODEL_LLM_PATH_STRING%\"}" -set "curl_data2={\"messages\":[{\"content\":\"Hello there\",\"role\":\"assistant\"},{\"content\":\"Write a long and sad story for me\",\"role\":\"user\"}],\"stream\":false,\"model\":\"gpt-3.5-turbo\",\"max_tokens\":50,\"stop\":[\"hello\"],\"frequency_penalty\":0,\"presence_penalty\":0,\"temperature\":0.1}" +set "curl_data2={\"messages\":[{\"content\":\"Hello there\",\"role\":\"assistant\"},{\"content\":\"Write a long and sad story for me\",\"role\":\"user\"}],\"stream\":true,\"model\":\"testllm\",\"max_tokens\":50,\"stop\":[\"hello\"],\"frequency_penalty\":0,\"presence_penalty\":0,\"temperature\":0.1}" set "curl_data3={\"llama_model_path\":\"%MODEL_LLM_PATH_STRING%\"}" set "curl_data4={\"llama_model_path\":\"%MODEL_EMBEDDING_PATH_STRING%\", \"embedding\": true, \"model_type\": \"embedding\"}" set "curl_data5={\"input\": \"Hello\", \"model\": \"test-embedding\", \"encoding_format\": \"float\"}" diff --git a/context/llama_server_context.h b/context/llama_server_context.h index 21792f11b..deee1d948 100644 --- a/context/llama_server_context.h +++ b/context/llama_server_context.h @@ -504,6 +504,7 @@ struct llama_server_context { std::condition_variable condition_tasks; std::mutex mutex_results; std::condition_variable condition_results; + std::thread bgr_thread; ModelType model_type = ModelType::LLM; ~llama_server_context() { @@ -515,6 +516,7 @@ struct llama_server_context { llama_free_model(model); model = nullptr; } + release_resources(); } bool load_model(const gpt_params& params_) { @@ -603,6 +605,10 @@ struct llama_server_context { // empty system prompt system_prompt = ""; system_tokens.clear(); + + model_loaded_external = true; + LOG_INFO << "Started background task here!"; + bgr_thread = std::thread(std::bind(&llama_server_context::do_background_tasks, this)); } std::vector tokenize(const json& json_prompt, @@ -1882,6 +1888,33 @@ struct llama_server_context { } return true; } + + void do_background_tasks() { + while (model_loaded_external) { + update_slots(); + } + LOG_INFO << "Background task stopped! 
"; + kv_cache_clear(); + LOG_INFO << "KV cache cleared!"; + } + + void release_resources() { + if(model_loaded_external) { + LOG_INFO << "Releasing llama_server_context resources"; + model_loaded_external = false; + condition_tasks.notify_one(); + + if (bgr_thread.joinable()) { + bgr_thread.join(); + } + + llama_free(ctx); + llama_free_model(model); + ctx = nullptr; + model = nullptr; + LOG_INFO << "Released llama_server_context resources"; + } + } }; static void server_print_usage(const char* argv0, const gpt_params& params, diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc index 69284d6e9..12ee92c63 100644 --- a/controllers/llamaCPP.cc +++ b/controllers/llamaCPP.cc @@ -28,11 +28,11 @@ enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED }; struct inferenceState { int task_id; InferenceStatus inference_status = PENDING; - llamaCPP* instance; + llama_server_context& llama; // Check if we receive the first token, set it to false after receiving bool is_first_token = true; - inferenceState(llamaCPP* inst) : instance(inst) {} + inferenceState(llama_server_context& l) : llama(l) {} }; /** @@ -40,8 +40,9 @@ struct inferenceState { * inferenceState will be persisting even tho the lambda in streaming might go * out of scope and the handler already moved on */ -std::shared_ptr create_inference_state(llamaCPP* instance) { - return std::make_shared(instance); +std::shared_ptr create_inference_state( + llama_server_context& l) { + return std::make_shared(l); } /** @@ -49,9 +50,13 @@ std::shared_ptr create_inference_state(llamaCPP* instance) { * @param callback the function to return message to user */ bool llamaCPP::CheckModelLoaded( - const std::function& callback) { - if (!llama.model_loaded_external) { - LOG_ERROR << "Model has not been loaded"; + const std::function& callback, + const std::string& model_id) { + if (auto l = server_ctx_map.find(model_id); + l == server_ctx_map.end() || !l->second.model_loaded_external) { + LOG_ERROR << "Error: model_id: " << model_id + << ", existed: " << (l == server_ctx_map.end()) + << ", loaded: " << (l != server_ctx_map.end()); Json::Value jsonResp; jsonResp["message"] = "Model has not been loaded, please load model into nitro"; @@ -143,41 +148,39 @@ std::string create_return_json(const std::string& id, const std::string& model, return Json::writeString(writer, root); } -llamaCPP::llamaCPP() - : queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel, - "llamaCPP")) { +llamaCPP::llamaCPP() { // Some default values for now below log_disable(); // Disable the log to file feature, reduce bloat for // target // system () }; -llamaCPP::~llamaCPP() { - StopBackgroundTask(); -} - -void llamaCPP::WarmupModel() { - json pseudo; - - LOG_INFO << "Warm-up model"; - pseudo["prompt"] = "Hello"; - pseudo["n_predict"] = 2; - pseudo["stream"] = false; - const int task_id = llama.request_completion(pseudo, false, false, -1); - std::string completion_text; - task_result result = llama.next_result(task_id); - if (!result.error && result.stop) { - LOG_INFO << result.result_json.dump(-1, ' ', false, - json::error_handler_t::replace); +llamaCPP::~llamaCPP() {} + +void llamaCPP::WarmupModel(const std::string& model_id) { + if (auto l = server_ctx_map.find(model_id); l != server_ctx_map.end()) { + json pseudo; + + LOG_INFO << "Warm-up model: " << model_id; + pseudo["prompt"] = "Hello"; + pseudo["n_predict"] = 2; + pseudo["stream"] = false; + const int task_id = l->second.request_completion(pseudo, false, false, -1); + task_result result = 
l->second.next_result(task_id); + if (!result.error && result.stop) { + LOG_INFO << result.result_json.dump(-1, ' ', false, + json::error_handler_t::replace); + } + } else { + LOG_WARN << "Model not found " << model_id; } - return; } void llamaCPP::ChatCompletion( inferences::ChatCompletionRequest&& completion, std::function&& callback) { // Check if model is loaded - if (CheckModelLoaded(callback)) { + if (CheckModelLoaded(callback, completion.model_id)) { // Model is loaded // Do Inference InferenceImpl(std::move(completion), std::move(callback)); @@ -187,7 +190,10 @@ void llamaCPP::ChatCompletion( void llamaCPP::InferenceImpl( inferences::ChatCompletionRequest&& completion, std::function&& callback) { - if (llama.model_type == ModelType::EMBEDDING) { + assert(server_ctx_map.find(completion.model_id) != server_ctx_map.end()); + auto& l = server_ctx_map[completion.model_id]; + + if (l.model_type == ModelType::EMBEDDING) { LOG_WARN << "Not support completion for embedding model"; Json::Value jsonResp; jsonResp["message"] = "Not support completion for embedding model"; @@ -196,6 +202,7 @@ void llamaCPP::InferenceImpl( callback(resp); return; } + std::string formatted_output = pre_prompt; int request_id = ++no_of_requests; LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request"; @@ -235,7 +242,7 @@ void llamaCPP::InferenceImpl( data["grammar"] = grammar_file_content; }; - if (!llama.multimodal) { + if (!l.multimodal) { for (const auto& message : messages) { std::string input_role = message["role"].asString(); std::string role; @@ -341,7 +348,8 @@ void llamaCPP::InferenceImpl( if (is_streamed) { LOG_INFO_REQUEST(request_id) << "Streamed, waiting for respone"; - auto state = create_inference_state(this); + + auto state = create_inference_state(l); auto chunked_content_provider = [state, data, request_id]( char* pBuffer, @@ -375,7 +383,7 @@ void llamaCPP::InferenceImpl( return nRead; } - task_result result = state->instance->llama.next_result(state->task_id); + task_result result = state->llama.next_result(state->task_id); if (!result.error) { std::string to_send = result.result_json["content"]; @@ -413,11 +421,10 @@ void llamaCPP::InferenceImpl( return 0; }; // Queued task - state->instance->queue->runTaskInQueue([cb = std::move(callback), state, - data, chunked_content_provider, - request_id]() { - state->task_id = - state->instance->llama.request_completion(data, false, false, -1); + ifr_task_queue_map[completion.model_id]->runTaskInQueue([cb = std::move(callback), state, data, + chunked_content_provider, + request_id]() { + state->task_id = state->llama.request_completion(data, false, false, -1); // Start streaming response auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider, @@ -429,7 +436,7 @@ void llamaCPP::InferenceImpl( // Since this is an async task, we will wait for the task to be // completed while (state->inference_status != FINISHED && retries < 10 && - state->instance->llama.model_loaded_external) { + state->llama.model_loaded_external) { // Should wait chunked_content_provider lambda to be called within // 3s if (state->inference_status == PENDING) { @@ -442,35 +449,37 @@ void llamaCPP::InferenceImpl( } LOG_INFO_REQUEST(request_id) << "Task completed, release it"; // Request completed, release it - state->instance->llama.request_cancel(state->task_id); + state->llama.request_cancel(state->task_id); LOG_INFO_REQUEST(request_id) << "Inference completed"; }); } else { - queue->runTaskInQueue( - [this, request_id, cb = std::move(callback), d = 
std::move(data)]() { - Json::Value respData; - int task_id = llama.request_completion(d, false, false, -1); - LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone"; - if (!json_value(d, "stream", false)) { - std::string completion_text; - task_result result = llama.next_result(task_id); - if (!result.error && result.stop) { - int prompt_tokens = result.result_json["tokens_evaluated"]; - int predicted_tokens = result.result_json["tokens_predicted"]; - std::string to_send = result.result_json["content"]; - nitro_utils::ltrim(to_send); - respData = create_full_return_json( - nitro_utils::generate_random_string(20), "_", to_send, "_", - prompt_tokens, predicted_tokens); - } else { - respData["message"] = "Internal error during inference"; - LOG_ERROR_REQUEST(request_id) << "Error during inference"; - } - auto resp = nitro_utils::nitroHttpJsonResponse(respData); - cb(resp); - LOG_INFO_REQUEST(request_id) << "Inference completed"; - } - }); + auto state = create_inference_state(l); + ifr_task_queue_map[completion.model_id]->runTaskInQueue([this, request_id, state, + cb = std::move(callback), + d = std::move(data)]() { + Json::Value respData; + int task_id = state->llama.request_completion(d, false, false, -1); + LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone"; + if (!json_value(d, "stream", false)) { + std::string completion_text; + task_result result = state->llama.next_result(task_id); + if (!result.error && result.stop) { + int prompt_tokens = result.result_json["tokens_evaluated"]; + int predicted_tokens = result.result_json["tokens_predicted"]; + std::string to_send = result.result_json["content"]; + nitro_utils::ltrim(to_send); + respData = create_full_return_json( + nitro_utils::generate_random_string(20), "_", to_send, "_", + prompt_tokens, predicted_tokens); + } else { + respData["message"] = "Internal error during inference"; + LOG_ERROR_REQUEST(request_id) << "Error during inference"; + } + auto resp = nitro_utils::nitroHttpJsonResponse(respData); + cb(resp); + LOG_INFO_REQUEST(request_id) << "Inference completed"; + } + }); } } @@ -478,7 +487,7 @@ void llamaCPP::Embedding( const HttpRequestPtr& req, std::function&& callback) { // Check if model is loaded - if (CheckModelLoaded(callback)) { + if (CheckModelLoaded(callback, nitro_utils::getModelId(req))) { // Model is loaded const auto& jsonBody = req->getJsonObject(); // Run embedding @@ -490,32 +499,35 @@ void llamaCPP::Embedding( void llamaCPP::EmbeddingImpl( std::shared_ptr jsonBody, std::function&& callback) { + auto model_id = nitro_utils::getModelId(*jsonBody); + assert(server_ctx_map.find(model_id) != server_ctx_map.end()); + int request_id = ++no_of_requests; LOG_INFO_REQUEST(request_id) << "Generating reponse for embedding request"; // Queue embedding task - auto state = create_inference_state(this); + auto state = create_inference_state(server_ctx_map[model_id]); - state->instance->queue->runTaskInQueue([this, state, jsonBody, callback, - request_id]() { + ifr_task_queue_map[model_id]->runTaskInQueue([this, state, jsonBody, callback, + request_id]() { Json::Value responseData(Json::arrayValue); if (jsonBody->isMember("input")) { const Json::Value& input = (*jsonBody)["input"]; if (input.isString()) { // Process the single string input - state->task_id = llama.request_completion( + state->task_id = state->llama.request_completion( {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1); - task_result result = llama.next_result(state->task_id); + task_result result = 
state->llama.next_result(state->task_id); std::vector embedding_result = result.result_json["embedding"]; responseData.append(create_embedding_payload(embedding_result, 0)); } else if (input.isArray()) { // Process each element in the array input for (const auto& elem : input) { if (elem.isString()) { - const int task_id = llama.request_completion( + const int task_id = state->llama.request_completion( {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1); - task_result result = llama.next_result(task_id); + task_result result = state->llama.next_result(task_id); std::vector embedding_result = result.result_json["embedding"]; responseData.append(create_embedding_payload(embedding_result, 0)); @@ -526,7 +538,7 @@ void llamaCPP::EmbeddingImpl( Json::Value root; root["data"] = responseData; - root["model"] = "_"; + root["model"] = nitro_utils::getModelId(*jsonBody); root["object"] = "list"; Json::Value usage; usage["prompt_tokens"] = 0; @@ -543,16 +555,15 @@ void llamaCPP::UnloadModel( const HttpRequestPtr& req, std::function&& callback) { Json::Value jsonResp; - if (CheckModelLoaded(callback)) { - StopBackgroundTask(); - - llama_free(llama.ctx); - llama_free_model(llama.model); - llama.ctx = nullptr; - llama.model = nullptr; + auto model_id = nitro_utils::getModelId(req); + if (CheckModelLoaded(callback, model_id)) { + auto& l = server_ctx_map[model_id]; + l.release_resources(); jsonResp["message"] = "Model unloaded successfully"; auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); callback(resp); + server_ctx_map.erase(model_id); + ifr_task_queue_map.erase(model_id); LOG_INFO << "Model unloaded successfully"; } } @@ -560,11 +571,14 @@ void llamaCPP::UnloadModel( void llamaCPP::ModelStatus( const HttpRequestPtr& req, std::function&& callback) { - Json::Value jsonResp; - bool is_model_loaded = llama.model_loaded_external; - if (CheckModelLoaded(callback)) { - jsonResp["model_loaded"] = is_model_loaded; - jsonResp["model_data"] = llama.get_model_props().dump(); + + auto model_id = nitro_utils::getModelId(req); + if (auto is_loaded = CheckModelLoaded(callback, model_id); is_loaded) { + // CheckModelLoaded gurantees that model_id exists in server_ctx_map; + auto l = server_ctx_map.find(model_id); + Json::Value jsonResp; + jsonResp["model_loaded"] = is_loaded; + jsonResp["model_data"] = l->second.get_model_props().dump(); auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); callback(resp); LOG_INFO << "Model status responded"; @@ -574,7 +588,6 @@ void llamaCPP::ModelStatus( void llamaCPP::LoadModel( const HttpRequestPtr& req, std::function&& callback) { - if (!nitro_utils::isAVX2Supported() && ggml_cpu_has_avx2()) { LOG_ERROR << "AVX2 is not supported by your processor"; Json::Value jsonResp; @@ -587,8 +600,20 @@ void llamaCPP::LoadModel( return; } - if (llama.model_loaded_external) { - LOG_INFO << "Model already loaded"; + auto model_id = nitro_utils::getModelId(req); + if (model_id.empty()) { + LOG_INFO << "Model id is empty in request"; + Json::Value jsonResp; + jsonResp["message"] = "No model id found in request body"; + auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + + if (auto l = server_ctx_map.find(model_id); + l != server_ctx_map.end() && l->second.model_loaded_external) { + LOG_INFO << "Model already loaded: " << model_id; Json::Value jsonResp; jsonResp["message"] = "Model already loaded"; auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); @@ -611,7 +636,7 @@ void 
llamaCPP::LoadModel( jsonResp["message"] = "Model loaded successfully"; auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp); callback(resp); - LOG_INFO << "Model loaded successfully"; + LOG_INFO << "Model loaded successfully: " << model_id; } } @@ -663,11 +688,7 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr jsonBody) { params.n_ctx = jsonBody->get("ctx_len", 2048).asInt(); params.embedding = jsonBody->get("embedding", true).asBool(); model_type = jsonBody->get("model_type", "llm").asString(); - if (model_type == "llm") { - llama.model_type = ModelType::LLM; - } else { - llama.model_type = ModelType::EMBEDDING; - } + // Check if n_parallel exists in jsonBody, if not, set to drogon_thread params.n_batch = jsonBody->get("n_batch", 512).asInt(); params.n_parallel = jsonBody->get("n_parallel", 1).asInt(); @@ -696,62 +717,48 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr jsonBody) { params.model_alias = params.model; } - llama_backend_init(); + if (ShouldInitBackend()) { + llama_backend_init(); + LOG_INFO_LLAMA("system info", + { + {"n_threads", params.n_threads}, + {"total_threads", std::thread::hardware_concurrency()}, + {"system_info", llama_print_system_info()}, + }); + } - // LOG_INFO_LLAMA("build info", - // {{"build", BUILD_NUMBER}, {"commit", BUILD_COMMIT}}); - LOG_INFO_LLAMA("system info", - { - {"n_threads", params.n_threads}, - {"total_threads", std::thread::hardware_concurrency()}, - {"system_info", llama_print_system_info()}, - }); + auto model_id = nitro_utils::getModelId(*jsonBody); // load the model - if (!llama.load_model(params)) { + if (!server_ctx_map[model_id].load_model(params)) { LOG_ERROR << "Error loading the model"; return false; // Indicate failure } - llama.initialize(); - if (queue != nullptr) { - delete queue; + if (model_type == "llm") { + server_ctx_map[model_id].model_type = ModelType::LLM; + } else { + server_ctx_map[model_id].model_type = ModelType::EMBEDDING; } - queue = new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP"); - - llama.model_loaded_external = true; - - LOG_INFO << "Started background task here!"; - backgroundThread = std::thread(&llamaCPP::BackgroundTask, this); + server_ctx_map[model_id].initialize(); + + ifr_task_queue_map.emplace(model_id, std::make_unique( + params.n_parallel, model_id)); // For model like nomic-embed-text-v1.5.f16.gguf, etc, we don't need to warm up model. // So we use this variable to differentiate with other models - if (llama.model_type == ModelType::LLM) { - WarmupModel(); + if (server_ctx_map[model_id].model_type == ModelType::LLM) { + WarmupModel(model_id); } return true; } -void llamaCPP::BackgroundTask() { - while (llama.model_loaded_external) { - // model_loaded = - llama.update_slots(); - } - LOG_INFO << "Background task stopped! "; - llama.kv_cache_clear(); - LOG_INFO << "KV cache cleared!"; - return; -} - -void llamaCPP::StopBackgroundTask() { - if (llama.model_loaded_external) { - llama.model_loaded_external = false; - llama.condition_tasks.notify_one(); - LOG_INFO << "Stopping background task! "; - if (backgroundThread.joinable()) { - backgroundThread.join(); - } - LOG_INFO << "Background task stopped! 
"; +bool llamaCPP::ShouldInitBackend() const { + // May have race condition here, need to check + for (auto& [_, l] : server_ctx_map) { + if (l.model_loaded_external) + return false; } + return true; } diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h index 531c18b20..73f5f3393 100644 --- a/controllers/llamaCPP.h +++ b/controllers/llamaCPP.h @@ -75,11 +75,9 @@ class llamaCPP : public drogon::HttpController, std::function&& callback) override; private: - llama_server_context llama; - // std::atomic model_loaded = false; + std::unordered_map server_ctx_map; size_t sent_count = 0; size_t sent_token_probs_index = 0; - std::thread backgroundThread; std::string user_prompt; std::string ai_prompt; std::string system_prompt; @@ -94,7 +92,9 @@ class llamaCPP : public drogon::HttpController, /** * Queue to handle the inference tasks */ - trantor::ConcurrentTaskQueue* queue; + // TODO: should we move this to server_context? + std::unordered_map> + ifr_task_queue_map; bool LoadModelImpl(std::shared_ptr jsonBody); void InferenceImpl(inferences::ChatCompletionRequest&& completion, @@ -102,9 +102,9 @@ class llamaCPP : public drogon::HttpController, void EmbeddingImpl(std::shared_ptr jsonBody, std::function&& callback); bool CheckModelLoaded( - const std::function& callback); - void WarmupModel(); - void BackgroundTask(); - void StopBackgroundTask(); + const std::function& callback, + const std::string& model_id); + void WarmupModel(const std::string& model_id); + bool ShouldInitBackend() const; }; }; // namespace inferences diff --git a/models/chat_completion_request.h b/models/chat_completion_request.h index f4fd087f5..c69faa6ad 100644 --- a/models/chat_completion_request.h +++ b/models/chat_completion_request.h @@ -11,6 +11,7 @@ struct ChatCompletionRequest { float presence_penalty = 0; Json::Value stop = Json::Value(Json::arrayValue); Json::Value messages = Json::Value(Json::arrayValue); + std::string model_id; }; } // namespace inferences @@ -30,6 +31,7 @@ inline inferences::ChatCompletionRequest fromRequest(const HttpRequest& req) { (*jsonBody).get("presence_penalty", 0).asFloat(); completion.messages = (*jsonBody)["messages"]; completion.stop = (*jsonBody)["stop"]; + completion.model_id = (*jsonBody).get("model", {}).asString(); } return completion; } diff --git a/test/components/test_nitro_utils.cc b/test/components/test_nitro_utils.cc index adf3e976b..65f94deee 100644 --- a/test/components/test_nitro_utils.cc +++ b/test/components/test_nitro_utils.cc @@ -1,41 +1,113 @@ #include "gtest/gtest.h" #include "utils/nitro_utils.h" -class NitroUtilTest : public ::testing::Test { -}; +class NitroUtilTest : public ::testing::Test {}; TEST_F(NitroUtilTest, left_trim) { - { - std::string empty; - nitro_utils::ltrim(empty); - EXPECT_EQ(empty, ""); - } - - { - std::string s = "abc"; - std::string expected = "abc"; - nitro_utils::ltrim(s); - EXPECT_EQ(s, expected); - } - - { - std::string s = " abc"; - std::string expected = "abc"; - nitro_utils::ltrim(s); - EXPECT_EQ(s, expected); - } - - { - std::string s = "1 abc 2 "; - std::string expected = "1 abc 2 "; - nitro_utils::ltrim(s); - EXPECT_EQ(s, expected); - } - - { - std::string s = " |abc"; - std::string expected = "|abc"; - nitro_utils::ltrim(s); - EXPECT_EQ(s, expected); - } + { + std::string empty; + nitro_utils::ltrim(empty); + EXPECT_EQ(empty, ""); + } + + { + std::string s = "abc"; + std::string expected = "abc"; + nitro_utils::ltrim(s); + EXPECT_EQ(s, expected); + } + + { + std::string s = " abc"; + std::string expected = "abc"; + 
nitro_utils::ltrim(s); + EXPECT_EQ(s, expected); + } + + { + std::string s = "1 abc 2 "; + std::string expected = "1 abc 2 "; + nitro_utils::ltrim(s); + EXPECT_EQ(s, expected); + } + + { + std::string s = " |abc"; + std::string expected = "|abc"; + nitro_utils::ltrim(s); + EXPECT_EQ(s, expected); + } } + +TEST_F(NitroUtilTest, get_model_id) { + // linux + { + Json::Value data; + data["llama_model_path"] = + "e:/workspace/model/" + "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS.gguf"; + std::string expected = + "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS"; + EXPECT_EQ(nitro_utils::getModelId(data), expected); + } + + { + Json::Value data; + data["llama_model_path"] = + "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS.gguf"; + std::string expected = + "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS"; + EXPECT_EQ(nitro_utils::getModelId(data), expected); + } + + // windows + { + Json::Value data; + data["llama_model_path"] = + "e:\\workspace\\model\\" + "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS.gguf"; + std::string expected = + "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS"; + EXPECT_EQ(nitro_utils::getModelId(data), expected); + } + + { + Json::Value data; + data["llama_model_path"] = + "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS"; + std::string expected = "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS"; + EXPECT_EQ(nitro_utils::getModelId(data), expected); + } + + { + Json::Value data; + data["llama_model_path"] = ""; + data["model_alias"] = + "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS"; + std::string expected = "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS"; + EXPECT_EQ(nitro_utils::getModelId(data), expected); + } + + { + Json::Value data; + data["llama_model_path"] = ""; + std::string expected = ""; + EXPECT_EQ(nitro_utils::getModelId(data), expected); + } + + // For embedding request + { + Json::Value data; + data["model"] = + "nomic-embed-text-v1.5.f16"; + std::string expected = "nomic-embed-text-v1.5.f16"; + EXPECT_EQ(nitro_utils::getModelId(data), expected); + } + + { + Json::Value data; + data["llama_model_path"] = "C:\\Users\\runneradmin\\AppData\\Local\\Temp\\testllm"; + std::string expected = "testllm"; + EXPECT_EQ(nitro_utils::getModelId(data), expected); + } +} \ No newline at end of file diff --git a/utils/nitro_utils.h b/utils/nitro_utils.h index c1087b345..68625be34 100644 --- a/utils/nitro_utils.h +++ b/utils/nitro_utils.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -282,4 +283,31 @@ inline void ltrim(std::string& s) { })); }; +inline std::string getModelId(const Json::Value& jsonBody) { + // First check if model exists in request + if(!jsonBody["model"].isNull()) { + return jsonBody["model"].asString(); + } else if(!jsonBody["model_alias"].isNull()) { + return jsonBody["model_alias"].asString(); + } + + // We check llama_model_path for loadmodel request + if (auto input = jsonBody["llama_model_path"]; !input.isNull()) { + auto s = input.asString(); + std::replace(s.begin(), s.end(), '\\', '/'); + auto const pos = s.find_last_of('/'); + // We only truncate the extension if file name has gguf extension + if(s.substr(s.find_last_of('.') + 1) == "gguf") { + return s.substr(pos + 1, s.find_last_of('.') - pos - 1); + } else { + return s.substr(pos + 1); + } + } + return {}; +} + +inline std::string getModelId(const drogon::HttpRequestPtr& req) { + return getModelId(*(req->getJsonObject())); +} + } // namespace nitro_utils
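
Reviewer sketch (not part of the patch): the change keys every request off a model id, resolved by the new nitro_utils::getModelId and then used to index both server_ctx_map and ifr_task_queue_map. The standalone restatement below mirrors that resolution rule for reference only; deriveModelId and the sample inputs are hypothetical, and jsoncpp is assumed to be available. Precedence is "model", then "model_alias", then the basename of "llama_model_path" with a trailing ".gguf" stripped.

// Condensed restatement of the id-resolution rule introduced in utils/nitro_utils.h.
#include <json/json.h>
#include <algorithm>
#include <iostream>
#include <string>

std::string deriveModelId(const Json::Value& body) {
  // 1) Chat/embedding requests carry the id in "model".
  if (!body["model"].isNull()) return body["model"].asString();
  // 2) loadmodel requests may pass an explicit "model_alias".
  if (!body["model_alias"].isNull()) return body["model_alias"].asString();
  // 3) Otherwise fall back to the model file name from "llama_model_path".
  const auto& path = body["llama_model_path"];
  if (!path.isNull()) {
    auto s = path.asString();
    std::replace(s.begin(), s.end(), '\\', '/');    // normalize Windows separators
    auto base = s.substr(s.find_last_of('/') + 1);  // npos + 1 == 0, so bare file names work too
    auto dot = base.find_last_of('.');
    if (dot != std::string::npos && base.substr(dot + 1) == "gguf")
      return base.substr(0, dot);                   // only a .gguf extension is stripped
    return base;
  }
  return {};
}

int main() {
  Json::Value load;  // loadmodel request: id comes from the file path
  load["llama_model_path"] = "C:\\models\\testllm.gguf";
  std::cout << deriveModelId(load) << "\n";  // testllm

  Json::Value chat;  // chat completion request: id comes from "model"
  chat["model"] = "testllm";
  std::cout << deriveModelId(chat) << "\n";  // testllm
  return 0;
}

As the unit tests above indicate, an empty "llama_model_path" with no alias resolves to an empty id, which LoadModel now rejects with a 400 response; callers must therefore supply a resolvable path, a "model_alias", or a "model" field.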