diff --git a/.github/scripts/e2e-test-llama-linux-and-mac.sh b/.github/scripts/e2e-test-llama-linux-and-mac.sh
index e97c51f63..b809f5127 100644
--- a/.github/scripts/e2e-test-llama-linux-and-mac.sh
+++ b/.github/scripts/e2e-test-llama-linux-and-mac.sh
@@ -45,6 +45,7 @@ response1=$(curl --connect-timeout 60 -o /tmp/response1.log -s -w "%{http_code}"
     --header 'Content-Type: application/json' \
     --data '{
         "llama_model_path": "/tmp/testllm",
+        "model_alias": "gpt-3.5-turbo",
         "ctx_len": 50,
         "ngl": 32,
         "embedding": false
diff --git a/.github/scripts/e2e-test-llama-windows.bat b/.github/scripts/e2e-test-llama-windows.bat
index a6526f358..d0ec31ce3 100644
--- a/.github/scripts/e2e-test-llama-windows.bat
+++ b/.github/scripts/e2e-test-llama-windows.bat
@@ -53,7 +53,7 @@ if not exist "%MODEL_PATH%" (
 rem Define JSON strings for curl data
 call set "MODEL_PATH_STRING=%%MODEL_PATH:\=\\%%"
 set "curl_data1={\"llama_model_path\":\"%MODEL_PATH_STRING%\"}"
-set "curl_data2={\"messages\":[{\"content\":\"Hello there\",\"role\":\"assistant\"},{\"content\":\"Write a long and sad story for me\",\"role\":\"user\"}],\"stream\":true,\"model\":\"gpt-3.5-turbo\",\"max_tokens\":50,\"stop\":[\"hello\"],\"frequency_penalty\":0,\"presence_penalty\":0,\"temperature\":0.1}"
+set "curl_data2={\"messages\":[{\"content\":\"Hello there\",\"role\":\"assistant\"},{\"content\":\"Write a long and sad story for me\",\"role\":\"user\"}],\"stream\":true,\"model\":\"testllm\",\"max_tokens\":50,\"stop\":[\"hello\"],\"frequency_penalty\":0,\"presence_penalty\":0,\"temperature\":0.1}"

 rem Print the values of curl_data1 and curl_data2 for debugging
 echo curl_data1=%curl_data1%
diff --git a/context/llama_server_context.h b/context/llama_server_context.h
index 3839e4b3f..e8ef06638 100644
--- a/context/llama_server_context.h
+++ b/context/llama_server_context.h
@@ -502,6 +502,7 @@ struct llama_server_context {
   std::condition_variable condition_tasks;
   std::mutex mutex_results;
   std::condition_variable condition_results;
+  std::thread bgr_thread;

   ~llama_server_context() {
     if (ctx) {
@@ -512,6 +513,7 @@ struct llama_server_context {
       llama_free_model(model);
       model = nullptr;
     }
+    release_resources();
   }

   bool load_model(const gpt_params& params_) {
@@ -600,6 +602,10 @@ struct llama_server_context {
     // empty system prompt
     system_prompt = "";
     system_tokens.clear();
+
+    model_loaded_external = true;
+    LOG_INFO << "Started background task here!";
+    bgr_thread = std::thread(std::bind(&llama_server_context::do_background_tasks, this));
   }

   std::vector<llama_token> tokenize(const json& json_prompt,
@@ -1879,6 +1885,33 @@ struct llama_server_context {
     }
     return true;
   }
+
+  void do_background_tasks() {
+    while (model_loaded_external) {
+      update_slots();
+    }
+    LOG_INFO << "Background task stopped! ";
+    kv_cache_clear();
+    LOG_INFO << "KV cache cleared!";
+  }
+
+  void release_resources() {
+    if(model_loaded_external) {
+      LOG_INFO << "Releasing llama_server_context resources";
+      model_loaded_external = false;
+      condition_tasks.notify_one();
+
+      if (bgr_thread.joinable()) {
+        bgr_thread.join();
+      }
+
+      llama_free(ctx);
+      llama_free_model(model);
+      ctx = nullptr;
+      model = nullptr;
+      LOG_INFO << "Released llama_server_context resources";
+    }
+  }
 };

 static void server_print_usage(const char* argv0, const gpt_params& params,
diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 3e3015c2a..17c180d8f 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -28,11 +28,11 @@ enum InferenceStatus { PENDING, RUNNING, EOS, FINISHED };
 struct inferenceState {
   int task_id;
   InferenceStatus inference_status = PENDING;
-  llamaCPP* instance;
+  llama_server_context& llama;
   // Check if we receive the first token, set it to false after receiving
   bool is_first_token = true;

-  inferenceState(llamaCPP* inst) : instance(inst) {}
+  inferenceState(llama_server_context& l) : llama(l) {}
 };

 /**
@@ -40,8 +40,9 @@ struct inferenceState {
  * inferenceState will be persisting even tho the lambda in streaming might go
  * out of scope and the handler already moved on
  */
-std::shared_ptr<inferenceState> create_inference_state(llamaCPP* instance) {
-  return std::make_shared<inferenceState>(instance);
+std::shared_ptr<inferenceState> create_inference_state(
+    llama_server_context& l) {
+  return std::make_shared<inferenceState>(l);
 }

 /**
@@ -49,9 +50,13 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP* instance) {
  * @param callback the function to return message to user
 */
 bool llamaCPP::CheckModelLoaded(
-    const std::function<void(const HttpResponsePtr&)>& callback) {
-  if (!llama.model_loaded_external) {
-    LOG_ERROR << "Model has not been loaded";
+    const std::function<void(const HttpResponsePtr&)>& callback,
+    const std::string& model_id) {
+  if (auto l = server_ctx_map.find(model_id);
+      l == server_ctx_map.end() || !l->second.model_loaded_external) {
+    LOG_ERROR << "Error: model_id: " << model_id
+              << ", existed: " << (l != server_ctx_map.end())
+              << ", loaded: " << (l != server_ctx_map.end() && l->second.model_loaded_external);
     Json::Value jsonResp;
     jsonResp["message"] =
         "Model has not been loaded, please load model into nitro";
@@ -143,41 +148,40 @@ std::string create_return_json(const std::string& id, const std::string& model,
   return Json::writeString(writer, root);
 }

-llamaCPP::llamaCPP()
-    : queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel,
-                                             "llamaCPP")) {
+llamaCPP::llamaCPP() {
   // Some default values for now below
   log_disable();  // Disable the log to file feature, reduce bloat for
                   // target
                   // system ()
 };

-llamaCPP::~llamaCPP() {
-  StopBackgroundTask();
-}
-
-void llamaCPP::WarmupModel() {
-  json pseudo;
-
-  LOG_INFO << "Warm-up model";
-  pseudo["prompt"] = "Hello";
-  pseudo["n_predict"] = 2;
-  pseudo["stream"] = false;
-  const int task_id = llama.request_completion(pseudo, false, false, -1);
-  std::string completion_text;
-  task_result result = llama.next_result(task_id);
-  if (!result.error && result.stop) {
-    LOG_INFO << result.result_json.dump(-1, ' ', false,
-                                        json::error_handler_t::replace);
+llamaCPP::~llamaCPP() {}
+
+void llamaCPP::WarmupModel(const std::string& model_id) {
+  if (auto l = server_ctx_map.find(model_id); l != server_ctx_map.end()) {
+    json pseudo;
+
+    LOG_INFO << "Warm-up model: " << model_id;
+    pseudo["prompt"] = "Hello";
+    pseudo["n_predict"] = 2;
+    pseudo["stream"] = false;
+    const int task_id =
+        l->second.request_completion(pseudo, false, false, -1);
+    task_result result = l->second.next_result(task_id);
+    if (!result.error && result.stop) {
+      LOG_INFO << result.result_json.dump(-1, ' ', false,
+                                          json::error_handler_t::replace);
+    }
+  } else {
+    LOG_WARN << "Model not found " << model_id;
   }
-  return;
 }

 void llamaCPP::ChatCompletion(
     inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>&& callback) {
   // Check if model is loaded
-  if (CheckModelLoaded(callback)) {
+  if (CheckModelLoaded(callback, completion.model_id)) {
     // Model is loaded
     // Do Inference
     InferenceImpl(std::move(completion), std::move(callback));
@@ -187,6 +191,9 @@ void llamaCPP::ChatCompletion(
 void llamaCPP::InferenceImpl(
     inferences::ChatCompletionRequest&& completion,
     std::function<void(const HttpResponsePtr&)>&& callback) {
+  assert(server_ctx_map.find(completion.model_id) != server_ctx_map.end());
+  auto& l = server_ctx_map[completion.model_id];
+
   std::string formatted_output = pre_prompt;
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request";
@@ -226,7 +233,7 @@ void llamaCPP::InferenceImpl(
     data["grammar"] = grammar_file_content;
   };

-  if (!llama.multimodal) {
+  if (!l.multimodal) {
     for (const auto& message : messages) {
       std::string input_role = message["role"].asString();
       std::string role;
@@ -332,7 +339,8 @@ void llamaCPP::InferenceImpl(

   if (is_streamed) {
     LOG_INFO_REQUEST(request_id) << "Streamed, waiting for respone";
-    auto state = create_inference_state(this);
+
+    auto state = create_inference_state(l);

     auto chunked_content_provider = [state, data, request_id](
                                         char* pBuffer,
@@ -366,7 +374,7 @@ void llamaCPP::InferenceImpl(
           return nRead;
         }

-        task_result result = state->instance->llama.next_result(state->task_id);
+        task_result result = state->llama.next_result(state->task_id);

         if (!result.error) {
           std::string to_send = result.result_json["content"];
@@ -404,11 +412,10 @@ void llamaCPP::InferenceImpl(
       return 0;
     };
     // Queued task
-    state->instance->queue->runTaskInQueue([cb = std::move(callback), state,
-                                            data, chunked_content_provider,
-                                            request_id]() {
-      state->task_id =
-          state->instance->llama.request_completion(data, false, false, -1);
+    inference_task_queue->runTaskInQueue(
+        [cb = std::move(callback), state, data, chunked_content_provider, request_id]() {
+          state->task_id =
+              state->llama.request_completion(data, false, false, -1);

       // Start streaming response
       auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
@@ -417,33 +424,35 @@ void llamaCPP::InferenceImpl(

       int retries = 0;

-      // Since this is an async task, we will wait for the task to be
-      // completed
-      while (state->inference_status != FINISHED && retries < 10) {
-        // Should wait chunked_content_provider lambda to be called within
-        // 3s
-        if (state->inference_status == PENDING) {
-          retries += 1;
-        }
-        if (state->inference_status != RUNNING)
-          LOG_INFO_REQUEST(request_id)
-              << "Wait for task to be released:" << state->task_id;
-        std::this_thread::sleep_for(std::chrono::milliseconds(100));
-      }
-      LOG_INFO_REQUEST(request_id) << "Task completed, release it";
-      // Request completed, release it
-      state->instance->llama.request_cancel(state->task_id);
-      LOG_INFO_REQUEST(request_id) << "Inference completed";
-    });
+          // Since this is an async task, we will wait for the task to be
+          // completed
+          while (state->inference_status != FINISHED && retries < 10) {
+            // Should wait chunked_content_provider lambda to be called within
+            // 3s
+            if (state->inference_status == PENDING) {
+              retries += 1;
+            }
+            if (state->inference_status != RUNNING)
+              LOG_INFO_REQUEST(request_id)
+                  << "Wait for task to be released:" << state->task_id;
+            std::this_thread::sleep_for(std::chrono::milliseconds(100));
+          }
+          LOG_INFO_REQUEST(request_id) << "Task completed, release it";
+          // Request completed, release it
+          state->llama.request_cancel(state->task_id);
+          LOG_INFO_REQUEST(request_id) << "Inference completed";
+        });
+
   } else {
-    queue->runTaskInQueue(
-        [this, request_id, cb = std::move(callback), d = std::move(data)]() {
+    auto state = create_inference_state(l);
+    inference_task_queue->runTaskInQueue(
+        [this, request_id, state, cb = std::move(callback), d = std::move(data)]() {
           Json::Value respData;
-          int task_id = llama.request_completion(d, false, false, -1);
+          int task_id = state->llama.request_completion(d, false, false, -1);
           LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
           if (!json_value(d, "stream", false)) {
             std::string completion_text;
-            task_result result = llama.next_result(task_id);
+            task_result result = state->llama.next_result(task_id);
             if (!result.error && result.stop) {
               int prompt_tokens = result.result_json["tokens_evaluated"];
               int predicted_tokens = result.result_json["tokens_predicted"];
@@ -468,7 +477,7 @@ void llamaCPP::Embedding(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
   // Check if model is loaded
-  if (CheckModelLoaded(callback)) {
+  if (CheckModelLoaded(callback, nitro_utils::getModelId(req))) {
     // Model is loaded
     const auto& jsonBody = req->getJsonObject();
     // Run embedding
@@ -480,32 +489,35 @@ void llamaCPP::Embedding(
 void llamaCPP::EmbeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr&)>&& callback) {
+  auto model_id = nitro_utils::getModelId(*jsonBody);
+  assert(server_ctx_map.find(model_id) != server_ctx_map.end());
+
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for embedding request";
   // Queue embedding task
-  auto state = create_inference_state(this);
+  auto state = create_inference_state(server_ctx_map[model_id]);

-  state->instance->queue->runTaskInQueue([this, state, jsonBody, callback,
-                                          request_id]() {
+  inference_task_queue->runTaskInQueue([this, state, jsonBody, callback,
+                                        request_id]() {
     Json::Value responseData(Json::arrayValue);

     if (jsonBody->isMember("input")) {
       const Json::Value& input = (*jsonBody)["input"];
       if (input.isString()) {
         // Process the single string input
-        state->task_id = llama.request_completion(
+        state->task_id = state->llama.request_completion(
             {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1);
-        task_result result = llama.next_result(state->task_id);
+        task_result result = state->llama.next_result(state->task_id);
         std::vector<float> embedding_result = result.result_json["embedding"];
         responseData.append(create_embedding_payload(embedding_result, 0));
       } else if (input.isArray()) {
         // Process each element in the array input
         for (const auto& elem : input) {
           if (elem.isString()) {
-            const int task_id = llama.request_completion(
+            const int task_id = state->llama.request_completion(
                 {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1);
-            task_result result = llama.next_result(task_id);
+            task_result result = state->llama.next_result(task_id);
             std::vector<float> embedding_result =
                 result.result_json["embedding"];
             responseData.append(create_embedding_payload(embedding_result, 0));
@@ -516,7 +528,7 @@ void llamaCPP::EmbeddingImpl(

     Json::Value root;
     root["data"] = responseData;
-    root["model"] = "_";
+    root["model"] = nitro_utils::getModelId(*jsonBody);
     root["object"] = "list";
     Json::Value usage;
     usage["prompt_tokens"] = 0;
@@ -533,16 +545,14 @@ void llamaCPP::UnloadModel(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
   Json::Value jsonResp;
-  if (CheckModelLoaded(callback)) {
-    StopBackgroundTask();
-
-    llama_free(llama.ctx);
-    llama_free_model(llama.model);
-    llama.ctx = nullptr;
-    llama.model = nullptr;
+  auto model_id = nitro_utils::getModelId(req);
+  if (CheckModelLoaded(callback, model_id)) {
+    auto& l = server_ctx_map[model_id];
+    l.release_resources();
     jsonResp["message"] = "Model unloaded successfully";
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
     callback(resp);
+    server_ctx_map.erase(model_id);
     LOG_INFO << "Model unloaded successfully";
   }
 }
@@ -550,11 +560,14 @@ void llamaCPP::UnloadModel(
 void llamaCPP::ModelStatus(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-  Json::Value jsonResp;
-  bool is_model_loaded = llama.model_loaded_external;
-  if (CheckModelLoaded(callback)) {
-    jsonResp["model_loaded"] = is_model_loaded;
-    jsonResp["model_data"] = llama.get_model_props().dump();
+
+  auto model_id = nitro_utils::getModelId(req);
+  if (auto is_loaded = CheckModelLoaded(callback, model_id); is_loaded) {
+    // CheckModelLoaded guarantees that model_id exists in server_ctx_map
+    auto l = server_ctx_map.find(model_id);
+    Json::Value jsonResp;
+    jsonResp["model_loaded"] = is_loaded;
+    jsonResp["model_data"] = l->second.get_model_props().dump();
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
     callback(resp);
     LOG_INFO << "Model status responded";
@@ -564,7 +577,6 @@ void llamaCPP::ModelStatus(
 void llamaCPP::LoadModel(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
-
   if (!nitro_utils::isAVX2Supported() && ggml_cpu_has_avx2()) {
     LOG_ERROR << "AVX2 is not supported by your processor";
     Json::Value jsonResp;
@@ -577,8 +589,20 @@ void llamaCPP::LoadModel(
     return;
   }

-  if (llama.model_loaded_external) {
-    LOG_INFO << "Model already loaded";
+  auto model_id = nitro_utils::getModelId(req);
+  if(model_id.empty()) {
+    LOG_INFO << "Model id is empty in request";
+    Json::Value jsonResp;
+    jsonResp["message"] = "No model id found in request body";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    resp->setStatusCode(drogon::k400BadRequest);
+    callback(resp);
+    return;
+  }
+
+  if (auto l = server_ctx_map.find(model_id);
+      l != server_ctx_map.end() && l->second.model_loaded_external) {
+    LOG_INFO << "Model already loaded: " << model_id;
     Json::Value jsonResp;
     jsonResp["message"] = "Model already loaded";
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
@@ -601,7 +625,7 @@ void llamaCPP::LoadModel(
     jsonResp["message"] = "Model loaded successfully";
     auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
     callback(resp);
-    LOG_INFO << "Model loaded successfully";
+    LOG_INFO << "Model loaded successfully: " << model_id;
   }
 }

@@ -681,62 +705,47 @@ bool llamaCPP::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
     params.model_alias = params.model;
   }

-  llama_backend_init();
-
-  // LOG_INFO_LLAMA("build info",
-  //                {{"build", BUILD_NUMBER}, {"commit", BUILD_COMMIT}});
-  LOG_INFO_LLAMA("system info",
-                 {
-                     {"n_threads", params.n_threads},
-                     {"total_threads", std::thread::hardware_concurrency()},
-                     {"system_info", llama_print_system_info()},
-                 });
+  if (ShouldInitBackend()) {
+    llama_backend_init();
+    LOG_INFO_LLAMA("system info",
+                   {
+                       {"n_threads", params.n_threads},
+                       {"total_threads", std::thread::hardware_concurrency()},
+                       {"system_info", llama_print_system_info()},
+                   });
+  }
+  auto model_id = nitro_utils::getModelId(*jsonBody);

   // load the model
-  if (!llama.load_model(params)) {
+  if (!server_ctx_map[model_id].load_model(params)) {
     LOG_ERROR << "Error loading the model";
     return false;  // Indicate failure
   }
-  llama.initialize();
-
-  if (queue != nullptr) {
-    delete queue;
+  server_ctx_map[model_id].initialize();
+
+  if (inference_task_queue == nullptr ||
+      task_queue_thread_num < params.n_parallel) {
+    task_queue_thread_num = std::max(task_queue_thread_num, params.n_parallel);
+    LOG_INFO << "Start inference task queue, num threads: "
+             << task_queue_thread_num;
+    inference_task_queue = std::make_unique<trantor::ConcurrentTaskQueue>(
+        task_queue_thread_num, "llamaCPP");
   }
-  queue = new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP");
-
-  llama.model_loaded_external = true;
-
-  LOG_INFO << "Started background task here!";
-  backgroundThread = std::thread(&llamaCPP::BackgroundTask, this);
-
   // For model like nomic-embed-text-v1.5.f16.gguf, etc, we don't need to warm up model.
   // So we use this variable to differentiate with other models
   // TODO: in case embedded model only, we should reject completion request from user?
   if (model_type == "llm") {
-    WarmupModel();
+    WarmupModel(model_id);
   }
   return true;
 }

-void llamaCPP::BackgroundTask() {
-  while (llama.model_loaded_external) {
-    // model_loaded =
-    llama.update_slots();
-  }
-  LOG_INFO << "Background task stopped! ";
-  llama.kv_cache_clear();
-  LOG_INFO << "KV cache cleared!";
-  return;
-}
-
-void llamaCPP::StopBackgroundTask() {
-  if (llama.model_loaded_external) {
-    llama.model_loaded_external = false;
-    llama.condition_tasks.notify_one();
-    LOG_INFO << "Background task stopped! ";
-    if (backgroundThread.joinable()) {
-      backgroundThread.join();
-    }
+bool llamaCPP::ShouldInitBackend() const {
+  // May have race condition here, need to check
+  for (auto& [_, l] : server_ctx_map) {
+    if (l.model_loaded_external)
+      return false;
   }
+  return true;
 }
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index 900786c79..96632a65e 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -59,13 +59,14 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   // PATH_ADD("/llama/chat_completion", Post);
   METHOD_LIST_END
   void ChatCompletion(
-      inferences::ChatCompletionRequest &&completion,
+      inferences::ChatCompletionRequest&& completion,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
   void Embedding(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
-  void LoadModel(const HttpRequestPtr& req,
-                 std::function<void(const HttpResponsePtr&)>&& callback) override;
+  void LoadModel(
+      const HttpRequestPtr& req,
+      std::function<void(const HttpResponsePtr&)>&& callback) override;
   void UnloadModel(
       const HttpRequestPtr& req,
       std::function<void(const HttpResponsePtr&)>&& callback) override;
@@ -74,11 +75,9 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
       std::function<void(const HttpResponsePtr&)>&& callback) override;

  private:
-  llama_server_context llama;
-  // std::atomic<bool> model_loaded = false;
+  std::unordered_map<std::string, llama_server_context> server_ctx_map;
   size_t sent_count = 0;
   size_t sent_token_probs_index = 0;
-  std::thread backgroundThread;
   std::string user_prompt;
   std::string ai_prompt;
   std::string system_prompt;
@@ -93,16 +92,18 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
   /**
    * Queue to handle the inference tasks
   */
-  trantor::ConcurrentTaskQueue* queue;
+  int task_queue_thread_num = 1;
+  std::unique_ptr<trantor::ConcurrentTaskQueue> inference_task_queue;

   bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
   void InferenceImpl(inferences::ChatCompletionRequest&& completion,
                      std::function<void(const HttpResponsePtr&)>&& callback);
   void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr&)>&& callback);
-  bool CheckModelLoaded(const std::function<void(const HttpResponsePtr&)>& callback);
-  void WarmupModel();
-  void BackgroundTask();
-  void StopBackgroundTask();
+  bool CheckModelLoaded(
+      const std::function<void(const HttpResponsePtr&)>& callback,
+      const std::string& model_id);
+  void WarmupModel(const std::string& model_id);
+  bool ShouldInitBackend() const;
 };
 };  // namespace inferences
diff --git a/models/chat_completion_request.h b/models/chat_completion_request.h
index f4fd087f5..c69faa6ad 100644
--- a/models/chat_completion_request.h
+++ b/models/chat_completion_request.h
@@ -11,6 +11,7 @@ struct ChatCompletionRequest {
   float presence_penalty = 0;
   Json::Value stop = Json::Value(Json::arrayValue);
   Json::Value messages = Json::Value(Json::arrayValue);
+  std::string model_id;
 };
 }  // namespace inferences

@@ -30,6 +31,7 @@ inline inferences::ChatCompletionRequest fromRequest(const HttpRequest& req) {
         (*jsonBody).get("presence_penalty", 0).asFloat();
     completion.messages = (*jsonBody)["messages"];
     completion.stop = (*jsonBody)["stop"];
+    completion.model_id = (*jsonBody).get("model", {}).asString();
   }
   return completion;
 }
diff --git a/test/components/test_nitro_utils.cc b/test/components/test_nitro_utils.cc
index adf3e976b..65f94deee 100644
--- a/test/components/test_nitro_utils.cc
+++ b/test/components/test_nitro_utils.cc
@@ -1,41 +1,113 @@
 #include "gtest/gtest.h"
 #include "utils/nitro_utils.h"

-class NitroUtilTest : public ::testing::Test {
-};
+class NitroUtilTest : public ::testing::Test {};

 TEST_F(NitroUtilTest, left_trim) {
-    {
-        std::string empty;
-        nitro_utils::ltrim(empty);
-        EXPECT_EQ(empty, "");
-    }
-
-    {
-        std::string s = "abc";
-        std::string expected = "abc";
-        nitro_utils::ltrim(s);
-        EXPECT_EQ(s, expected);
-    }
-
-    {
-        std::string s = " abc";
-        std::string expected = "abc";
-        nitro_utils::ltrim(s);
-        EXPECT_EQ(s, expected);
-    }
-
-    {
-        std::string s = "1 abc 2 ";
-        std::string expected = "1 abc 2 ";
-        nitro_utils::ltrim(s);
-        EXPECT_EQ(s, expected);
-    }
-
-    {
-        std::string s = " |abc";
-        std::string expected = "|abc";
-        nitro_utils::ltrim(s);
-        EXPECT_EQ(s, expected);
-    }
+  {
+    std::string empty;
+    nitro_utils::ltrim(empty);
+    EXPECT_EQ(empty, "");
+  }
+
+  {
+    std::string s = "abc";
+    std::string expected = "abc";
+    nitro_utils::ltrim(s);
+    EXPECT_EQ(s, expected);
+  }
+
+  {
+    std::string s = " abc";
+    std::string expected = "abc";
+    nitro_utils::ltrim(s);
+    EXPECT_EQ(s, expected);
+  }
+
+  {
+    std::string s = "1 abc 2 ";
+    std::string expected = "1 abc 2 ";
+    nitro_utils::ltrim(s);
+    EXPECT_EQ(s, expected);
+  }
+
+  {
+    std::string s = " |abc";
+    std::string expected = "|abc";
+    nitro_utils::ltrim(s);
+    EXPECT_EQ(s, expected);
+  }
 }
+
+TEST_F(NitroUtilTest, get_model_id) {
+  // linux
+  {
+    Json::Value data;
+    data["llama_model_path"] =
+        "e:/workspace/model/"
+        "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS.gguf";
+    std::string expected =
+        "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS";
+    EXPECT_EQ(nitro_utils::getModelId(data), expected);
+  }
+
+  {
+    Json::Value data;
+    data["llama_model_path"] =
+        "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS.gguf";
+    std::string expected =
+        "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS";
+    EXPECT_EQ(nitro_utils::getModelId(data), expected);
+  }
+
+  // windows
+  {
+    Json::Value data;
+    data["llama_model_path"] =
+        "e:\\workspace\\model\\"
+        "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS.gguf";
+    std::string expected =
+        "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS";
+    EXPECT_EQ(nitro_utils::getModelId(data), expected);
+  }
+
+  {
+    Json::Value data;
+    data["llama_model_path"] =
+        "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS";
+    std::string expected = "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS";
+    EXPECT_EQ(nitro_utils::getModelId(data), expected);
+  }
+
+  {
+    Json::Value data;
+    data["llama_model_path"] = "";
+    data["model_alias"] =
+        "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS";
+    std::string expected = "Starling_Monarch_Westlake_Garten-7B-v0.1-IQ4_XS";
+    EXPECT_EQ(nitro_utils::getModelId(data), expected);
+  }
+
+  {
+    Json::Value data;
+    data["llama_model_path"] = "";
+    std::string expected = "";
+    EXPECT_EQ(nitro_utils::getModelId(data), expected);
+  }
+
+  // For embedding request
+  {
+    Json::Value data;
+    data["model"] =
+        "nomic-embed-text-v1.5.f16";
+    std::string expected = "nomic-embed-text-v1.5.f16";
+    EXPECT_EQ(nitro_utils::getModelId(data), expected);
+  }
+
+  {
+    Json::Value data;
+    data["llama_model_path"] = "C:\\Users\\runneradmin\\AppData\\Local\\Temp\\testllm";
+    std::string expected = "testllm";
+    EXPECT_EQ(nitro_utils::getModelId(data), expected);
+  }
+}
\ No newline at end of file
diff --git a/utils/nitro_utils.h b/utils/nitro_utils.h
index c1087b345..570297ddf 100644
--- a/utils/nitro_utils.h
+++ b/utils/nitro_utils.h
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include <algorithm>
 #include
 #include
 #include
@@ -282,4 +283,31 @@ inline void ltrim(std::string& s) {
         }));
 };

+inline std::string getModelId(const Json::Value& jsonBody) {
+  // First check if model exists in request
+  if(!jsonBody["model"].isNull()) {
+    return jsonBody["model"].asString();
+  } else if(!jsonBody["model_alias"].isNull()) {
+    return jsonBody["model_alias"].asString();
+  }
+
+  // We check llama_model_path for loadmodel request
+  if (auto input = jsonBody["llama_model_path"]; !input.isNull()) {
+    auto s = input.asString();
+    std::replace(s.begin(), s.end(), '\\', '/');
+    auto const pos = s.find_last_of('/');
+    // Strip the extension only when the file name ends in .gguf; otherwise keep the name as-is
+    if(s.substr(s.find_last_of('.') + 1) == "gguf") {
+      return s.substr(pos + 1, s.find_last_of('.') - pos - 1);
+    } else {
+      return s.substr(pos + 1);
+    }
+  }
+  return {};
+}
+
+inline std::string getModelId(const drogon::HttpRequestPtr& req) {
+  return getModelId(*(req->getJsonObject()));
+}
+
 }  // namespace nitro_utils
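Usage sketch (not taken from the diff itself): with these changes a model is addressed by the id that nitro_utils::getModelId derives from the request, checking "model" first, then "model_alias", and finally the file-name stem of "llama_model_path". The host, port, and endpoint paths below are assumptions based on the existing e2e scripts and Nitro defaults, so adjust them to your deployment.

# Hypothetical manual walkthrough of the multi-model flow (endpoints and port assumed).
# 1) Load a model. Without "model_alias" the id would be the file stem ("testllm");
#    here the alias overrides it, mirroring the updated linux/mac e2e script.
curl -s http://localhost:3928/inferences/llamacpp/loadmodel \
  --header 'Content-Type: application/json' \
  --data '{"llama_model_path": "/tmp/testllm", "model_alias": "gpt-3.5-turbo", "ctx_len": 50, "ngl": 32}'

# 2) Address that model by id: fromRequest() copies "model" into ChatCompletionRequest::model_id,
#    and CheckModelLoaded() looks it up in server_ctx_map before inference runs.
curl -s http://localhost:3928/inferences/llamacpp/chat_completion \
  --header 'Content-Type: application/json' \
  --data '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'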