From bf49809c75dab92b5ed716eb0e12e75112f62eb2 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Sat, 11 May 2024 16:21:36 +0700 Subject: [PATCH] feat: multiple models --- .../scripts/e2e-test-server-linux-and-mac.sh | 5 +- .github/scripts/e2e-test-server-windows.bat | 4 +- examples/server/server.cc | 5 +- src/chat_completion_request.h | 2 + src/llama_engine.cc | 303 +-- src/llama_engine.h | 39 +- src/llama_server_context.cc | 2250 +++++++++-------- src/llama_server_context.h | 14 +- src/llama_utils.h | 23 + 9 files changed, 1357 insertions(+), 1288 deletions(-) diff --git a/.github/scripts/e2e-test-server-linux-and-mac.sh b/.github/scripts/e2e-test-server-linux-and-mac.sh index 5397c36..43c8278 100644 --- a/.github/scripts/e2e-test-server-linux-and-mac.sh +++ b/.github/scripts/e2e-test-server-linux-and-mac.sh @@ -51,6 +51,7 @@ response1=$(curl --connect-timeout 60 -o /tmp/load-llm-model-res.log -s -w "%{ht --header 'Content-Type: application/json' \ --data '{ "llama_model_path": "/tmp/testllm", + "model_alias": "testllm", "ctx_len": 50, "ngl": 32, "embedding": false @@ -73,7 +74,7 @@ response2=$( {"content": "Write a long and sad story for me", "role": "user"} ], "stream": true, - "model": "gpt-3.5-turbo", + "model": "testllm", "max_tokens": 50, "stop": ["hello"], "frequency_penalty": 0, @@ -83,7 +84,7 @@ response2=$( ) # unload model -response3=$(curl --connect-timeout 60 -o /tmp/unload-model-res.log --request GET -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/unloadmodel" \ +response3=$(curl --connect-timeout 60 -o /tmp/unload-model-res.log -s -w "%{http_code}" --location "http://127.0.0.1:$PORT/unloadmodel" \ --header 'Content-Type: application/json' \ --data '{ "llama_model_path": "/tmp/testllm" diff --git a/.github/scripts/e2e-test-server-windows.bat b/.github/scripts/e2e-test-server-windows.bat index b4641da..8609f6b 100644 --- a/.github/scripts/e2e-test-server-windows.bat +++ b/.github/scripts/e2e-test-server-windows.bat @@ -63,7 +63,7 @@ rem Define JSON strings for curl data call set "MODEL_LLM_PATH_STRING=%%MODEL_LLM_PATH:\=\\%%" call set "MODEL_EMBEDDING_PATH_STRING=%%MODEL_EMBEDDING_PATH:\=\\%%" set "curl_data1={\"llama_model_path\":\"%MODEL_LLM_PATH_STRING%\"}" -set "curl_data2={\"messages\":[{\"content\":\"Hello there\",\"role\":\"assistant\"},{\"content\":\"Write a long and sad story for me\",\"role\":\"user\"}],\"stream\":false,\"model\":\"gpt-3.5-turbo\",\"max_tokens\":50,\"stop\":[\"hello\"],\"frequency_penalty\":0,\"presence_penalty\":0,\"temperature\":0.1}" +set "curl_data2={\"messages\":[{\"content\":\"Hello there\",\"role\":\"assistant\"},{\"content\":\"Write a long and sad story for me\",\"role\":\"user\"}],\"stream\":false,\"model\":\"testllm\",\"max_tokens\":50,\"stop\":[\"hello\"],\"frequency_penalty\":0,\"presence_penalty\":0,\"temperature\":0.1}" set "curl_data3={\"llama_model_path\":\"%MODEL_LLM_PATH_STRING%\"}" set "curl_data4={\"llama_model_path\":\"%MODEL_EMBEDDING_PATH_STRING%\", \"embedding\": true, \"model_type\": \"embedding\"}" set "curl_data5={\"input\": \"Hello\", \"model\": \"test-embedding\", \"encoding_format\": \"float\"}" @@ -82,7 +82,7 @@ curl.exe --connect-timeout 60 -o "%TEMP%\response2.log" -s -w "%%{http_code}" -- --header "Content-Type: application/json" ^ --data "%curl_data2%" > %TEMP%\response2.log 2>&1 -curl.exe --connect-timeout 60 -o "%TEMP%\response3.log" --request GET -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/unloadmodel" --header "Content-Type: application/json" --data "%curl_data3%" > %TEMP%\response3.log 
2>&1 +curl.exe --connect-timeout 60 -o "%TEMP%\response3.log" -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/unloadmodel" --header "Content-Type: application/json" --data "%curl_data3%" > %TEMP%\response3.log 2>&1 curl.exe --connect-timeout 60 -o "%TEMP%\response4.log" --request POST -s -w "%%{http_code}" --location "http://127.0.0.1:%PORT%/loadmodel" --header "Content-Type: application/json" --data "%curl_data4%" > %TEMP%\response4.log 2>&1 diff --git a/examples/server/server.cc b/examples/server/server.cc index 364f627..bd17903 100644 --- a/examples/server/server.cc +++ b/examples/server/server.cc @@ -200,10 +200,11 @@ int main(int argc, char** argv) { }; svr->Post("/loadmodel", handle_load_model); - svr->Get("/unloadmodel", handle_unload_model); + // Use POST since httplib does not read request body for GET method + svr->Post("/unloadmodel", handle_unload_model); svr->Post("/v1/chat/completions", handle_completions); svr->Post("/v1/embeddings", handle_embeddings); - svr->Get("/modelstatus", handle_get_model_status); + svr->Post("/modelstatus", handle_get_model_status); LOG_INFO << "HTTP server listening: " << hostname << ":" << port; svr->new_task_queue = [] { diff --git a/src/chat_completion_request.h b/src/chat_completion_request.h index 9855fa0..b78dd66 100644 --- a/src/chat_completion_request.h +++ b/src/chat_completion_request.h @@ -11,6 +11,7 @@ struct ChatCompletionRequest { float presence_penalty = 0; Json::Value stop = Json::Value(Json::arrayValue); Json::Value messages = Json::Value(Json::arrayValue); + std::string model_id; }; inline ChatCompletionRequest fromJson(std::shared_ptr jsonBody) { @@ -26,6 +27,7 @@ inline ChatCompletionRequest fromJson(std::shared_ptr jsonBody) { (*jsonBody).get("presence_penalty", 0).asFloat(); completion.messages = (*jsonBody)["messages"]; completion.stop = (*jsonBody)["stop"]; + completion.model_id = (*jsonBody).get("model", {}).asString(); } return completion; } diff --git a/src/llama_engine.cc b/src/llama_engine.cc index 73800a1..c98e757 100644 --- a/src/llama_engine.cc +++ b/src/llama_engine.cc @@ -24,13 +24,12 @@ struct InferenceState { * InferenceState will be persisting even tho the lambda in streaming might go * out of scope and the handler already moved on */ -std::shared_ptr CreateInferenceState( - LlamaServerContext& l) { +std::shared_ptr CreateInferenceState(LlamaServerContext& l) { return std::make_shared(l); } Json::Value CreateEmbeddingPayload(const std::vector& embedding, - int prompt_tokens) { + int prompt_tokens) { Json::Value dataItem; dataItem["object"] = "embedding"; @@ -46,11 +45,11 @@ Json::Value CreateEmbeddingPayload(const std::vector& embedding, } Json::Value CreateFullReturnJson(const std::string& id, - const std::string& model, - const std::string& content, - const std::string& system_fingerprint, - int prompt_tokens, int completion_tokens, - Json::Value finish_reason = Json::Value()) { + const std::string& model, + const std::string& content, + const std::string& system_fingerprint, + int prompt_tokens, int completion_tokens, + Json::Value finish_reason = Json::Value()) { Json::Value root; root["id"] = id; @@ -82,8 +81,8 @@ Json::Value CreateFullReturnJson(const std::string& id, } std::string CreateReturnJson(const std::string& id, const std::string& model, - const std::string& content, - Json::Value finish_reason = Json::Value()) { + const std::string& content, + Json::Value finish_reason = Json::Value()) { Json::Value root; root["id"] = id; @@ -114,15 +113,13 @@ LlamaEngine::LlamaEngine() { log_disable(); 
} -LlamaEngine::~LlamaEngine() { - StopBackgroundTask(); -} +LlamaEngine::~LlamaEngine() {} void LlamaEngine::HandleChatCompletion( std::shared_ptr jsonBody, std::function&& callback) { // Check if model is loaded - if (CheckModelLoaded(callback)) { + if (CheckModelLoaded(callback, llama_utils::GetModelId(*jsonBody))) { // Model is loaded // Do Inference HandleInferenceImpl(llama::inferences::fromJson(jsonBody), @@ -134,7 +131,7 @@ void LlamaEngine::HandleEmbedding( std::shared_ptr jsonBody, std::function&& callback) { // Check if model is loaded - if (CheckModelLoaded(callback)) { + if (CheckModelLoaded(callback, llama_utils::GetModelId(*jsonBody))) { // Run embedding HandleEmbeddingImpl(jsonBody, std::move(callback)); } @@ -158,7 +155,22 @@ void LlamaEngine::LoadModel( return; } - if (llama_.model_loaded_external) { + auto model_id = llama_utils::GetModelId(*jsonBody); + if (model_id.empty()) { + LOG_INFO << "Model id is empty in request"; + Json::Value jsonResp; + jsonResp["message"] = "No model id found in request body"; + Json::Value status; + status["is_done"] = false; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(jsonResp)); + return; + } + + if (auto si = server_map_.find(model_id); + si != server_map_.end() && si->second.ctx.model_loaded_external) { LOG_INFO << "Model already loaded"; Json::Value jsonResp; jsonResp["message"] = "Model already loaded"; @@ -191,22 +203,18 @@ void LlamaEngine::LoadModel( status["is_stream"] = false; status["status_code"] = k200OK; callback(std::move(status), std::move(jsonResp)); - LOG_INFO << "Model loaded successfully"; + LOG_INFO << "Model loaded successfully: " << model_id; } } void LlamaEngine::UnloadModel( std::shared_ptr jsonBody, std::function&& callback) { + auto model_id = llama_utils::GetModelId(*jsonBody); + if (CheckModelLoaded(callback, model_id)) { + auto& l = server_map_[model_id].ctx; + l.ReleaseResources(); - if (CheckModelLoaded(callback)) { - StopBackgroundTask(); - - llama_free(llama_.ctx); - llama_free_model(llama_.model); - llama_.ctx = nullptr; - llama_.model = nullptr; - llama_backend_free(); Json::Value jsonResp; jsonResp["message"] = "Model unloaded successfully"; Json::Value status; @@ -216,6 +224,7 @@ void LlamaEngine::UnloadModel( status["status_code"] = k200OK; callback(std::move(status), std::move(jsonResp)); + server_map_.erase(model_id); LOG_INFO << "Model unloaded successfully"; } } @@ -224,11 +233,13 @@ void LlamaEngine::GetModelStatus( std::shared_ptr jsonBody, std::function&& callback) { - bool is_model_loaded = llama_.model_loaded_external; - if (CheckModelLoaded(callback)) { + auto model_id = llama_utils::GetModelId(*jsonBody); + if (auto is_loaded = CheckModelLoaded(callback, model_id); is_loaded) { + // CheckModelLoaded gurantees that model_id exists in server_ctx_map; + auto si = server_map_.find(model_id); Json::Value jsonResp; - jsonResp["model_loaded"] = is_model_loaded; - jsonResp["model_data"] = llama_.GetModelProps().dump(); + jsonResp["model_loaded"] = is_loaded; + jsonResp["model_data"] = si->second.ctx.GetModelProps().dump(); Json::Value status; status["is_done"] = true; status["has_error"] = false; @@ -242,6 +253,7 @@ void LlamaEngine::GetModelStatus( bool LlamaEngine::LoadModelImpl(std::shared_ptr jsonBody) { gpt_params params; std::string model_type; + auto model_id = llama_utils::GetModelId(*jsonBody); // By default will setting based on number of handlers if (jsonBody) { if 
(!jsonBody->operator[]("mmproj").isNull()) { @@ -267,13 +279,14 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr jsonBody) { } else { std::stringstream grammarBuf; grammarBuf << file.rdbuf(); - grammar_file_content_ = grammarBuf.str(); + server_map_[model_id].grammar_file_content = grammarBuf.str(); } }; Json::Value model_path = jsonBody->operator[]("llama_model_path"); if (model_path.isNull()) { LOG_ERROR << "Missing model path in request"; + //TODO return? } else { if (std::filesystem::exists( std::filesystem::path(model_path.asString()))) { @@ -287,11 +300,6 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr jsonBody) { params.n_ctx = jsonBody->get("ctx_len", 2048).asInt(); params.embedding = jsonBody->get("embedding", true).asBool(); model_type = jsonBody->get("model_type", "llm").asString(); - if (model_type == "llm") { - llama_.model_type = ModelType::kLlm; - } else { - llama_.model_type = ModelType::kEmbedding; - } // Check if n_parallel exists in jsonBody, if not, set to drogon_thread params.n_batch = jsonBody->get("n_batch", 512).asInt(); params.n_parallel = jsonBody->get("n_parallel", 1).asInt(); @@ -299,15 +307,18 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr jsonBody) { jsonBody->get("cpu_threads", std::thread::hardware_concurrency()) .asInt(); params.cont_batching = jsonBody->get("cont_batching", false).asBool(); - this->clean_cache_threshold_ = - jsonBody->get("clean_cache_threshold", 5).asInt(); - this->caching_enabled_ = jsonBody->get("caching_enabled", false).asBool(); - this->user_prompt_ = jsonBody->get("user_prompt", "USER: ").asString(); - this->ai_prompt_ = jsonBody->get("ai_prompt", "ASSISTANT: ").asString(); - this->system_prompt_ = + server_map_[model_id].caching_enabled = + jsonBody->get("caching_enabled", false).asBool(); + server_map_[model_id].user_prompt = + jsonBody->get("user_prompt", "USER: ").asString(); + server_map_[model_id].ai_prompt = + jsonBody->get("ai_prompt", "ASSISTANT: ").asString(); + server_map_[model_id].system_prompt = jsonBody->get("system_prompt", "ASSISTANT's RULE: ").asString(); - this->pre_prompt_ = jsonBody->get("pre_prompt", "").asString(); - this->repeat_last_n_ = jsonBody->get("repeat_last_n", 32).asInt(); + server_map_[model_id].pre_prompt = + jsonBody->get("pre_prompt", "").asString(); + server_map_[model_id].repeat_last_n = + jsonBody->get("repeat_last_n", 32).asInt(); if (!jsonBody->operator[]("llama_log_folder").isNull()) { log_enable(); @@ -320,36 +331,41 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr jsonBody) { params.model_alias = params.model; } - llama_backend_init(); - - // LOG_INFO_LLAMA("build info", - // {{"build", BUILD_NUMBER}, {"commit", BUILD_COMMIT}}); - LOG_INFO_LLAMA("system info", - { - {"n_threads", params.n_threads}, - {"total_threads", std::thread::hardware_concurrency()}, - {"system_info", llama_print_system_info()}, - }); + if (ShouldInitBackend()) { + llama_backend_init(); + + // LOG_INFO_LLAMA("build info", + // {{"build", BUILD_NUMBER}, {"commit", BUILD_COMMIT}}); + LOG_INFO_LLAMA("system info", + { + {"n_threads", params.n_threads}, + {"total_threads", std::thread::hardware_concurrency()}, + {"system_info", llama_print_system_info()}, + }); + } // load the model - if (!llama_.LoadModel(params)) { + if (!server_map_[model_id].ctx.LoadModel(params)) { LOG_ERROR << "Error loading the model"; + // TODO use ScopeExit + server_map_.erase(model_id); return false; // Indicate failure } - llama_.Initialize(); - queue_ = std::make_unique(params.n_parallel, - "llamaCPP"); - - 
llama_.model_loaded_external = true; + if (model_type == "llm") { + server_map_[model_id].ctx.model_type = ModelType::kLlm; + } else { + server_map_[model_id].ctx.model_type = ModelType::kEmbedding; + } + server_map_[model_id].ctx.Initialize(); - LOG_INFO << "Started background task here!"; - bgr_thread_ = std::thread(&LlamaEngine::HandleBackgroundTask, this); + server_map_[model_id].q = std::make_unique( + params.n_parallel, model_id); // For model like nomic-embed-text-v1.5.f16.gguf, etc, we don't need to warm up model. // So we use this variable to differentiate with other models - if (llama_.model_type == ModelType::kLlm) { - WarmUpModel(); + if (server_map_[model_id].ctx.model_type == ModelType::kLlm) { + WarmUpModel(model_id); } return true; } @@ -357,7 +373,9 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr jsonBody) { void LlamaEngine::HandleInferenceImpl( llama::inferences::ChatCompletionRequest&& completion, std::function&& callback) { - if (llama_.model_type == ModelType::kEmbedding) { + assert(server_map_.find(completion.model_id) != server_map_.end()); + auto& si = server_map_[completion.model_id]; + if (si.ctx.model_type == ModelType::kEmbedding) { LOG_WARN << "Not support completion for embedding model"; Json::Value jsonResp; jsonResp["message"] = "Not support completion for embedding model"; @@ -369,9 +387,10 @@ void LlamaEngine::HandleInferenceImpl( callback(std::move(status), std::move(jsonResp)); return; } - std::string formatted_output = pre_prompt_; + std::string formatted_output = si.pre_prompt; int request_id = ++no_of_requests_; - LOG_INFO << "Request " << request_id << ": " + LOG_INFO << "Request " << request_id << ", " << "model " + << completion.model_id << ": " << "Generating reponse for inference request"; json data; @@ -380,11 +399,11 @@ void LlamaEngine::HandleInferenceImpl( // To set default value // Default values to enable auto caching - data["cache_prompt"] = caching_enabled_; + data["cache_prompt"] = si.caching_enabled; data["n_keep"] = 0; // Passing load value - data["repeat_last_n"] = this->repeat_last_n_; + data["repeat_last_n"] = si.repeat_last_n; LOG_INFO << "Request " << request_id << ": " << "Stop words:" << completion.stop.toStyledString(); @@ -396,24 +415,24 @@ void LlamaEngine::HandleInferenceImpl( data["presence_penalty"] = completion.presence_penalty; const Json::Value& messages = completion.messages; - if (!grammar_file_content_.empty()) { - data["grammar"] = grammar_file_content_; + if (!si.grammar_file_content.empty()) { + data["grammar"] = si.grammar_file_content; }; - if (!llama_.multimodal) { + if (!si.ctx.multimodal) { for (const auto& message : messages) { std::string input_role = message["role"].asString(); std::string role; if (input_role == "user") { - role = user_prompt_; + role = si.user_prompt; std::string content = message["content"].asString(); formatted_output += role + content; } else if (input_role == "assistant") { - role = ai_prompt_; + role = si.ai_prompt; std::string content = message["content"].asString(); formatted_output += role + content; } else if (input_role == "system") { - role = system_prompt_; + role = si.system_prompt; std::string content = message["content"].asString(); formatted_output = role + content + formatted_output; @@ -423,7 +442,7 @@ void LlamaEngine::HandleInferenceImpl( formatted_output += role + content; } } - formatted_output += ai_prompt_; + formatted_output += si.ai_prompt; } else { data["image_data"] = json::array(); for (const auto& message : messages) { @@ -432,7 +451,7 @@ void 
LlamaEngine::HandleInferenceImpl( if (input_role == "user") { formatted_output += role; for (auto content_piece : message["content"]) { - role = user_prompt_; + role = si.user_prompt; json content_piece_image_data; content_piece_image_data["data"] = ""; @@ -471,11 +490,11 @@ void LlamaEngine::HandleInferenceImpl( } } else if (input_role == "assistant") { - role = ai_prompt_; + role = si.ai_prompt; std::string content = message["content"].asString(); formatted_output += role + content; } else if (input_role == "system") { - role = system_prompt_; + role = si.system_prompt; std::string content = message["content"].asString(); formatted_output = role + content + formatted_output; @@ -485,7 +504,7 @@ void LlamaEngine::HandleInferenceImpl( formatted_output += role + content; } } - formatted_output += ai_prompt_; + formatted_output += si.ai_prompt; LOG_INFO << "Request " << request_id << ": " << formatted_output; } @@ -496,25 +515,23 @@ void LlamaEngine::HandleInferenceImpl( // specify default stop words // Ensure success case for chatML stopWords.push_back("<|im_end|>"); - stopWords.push_back(llama_utils::rtrim(user_prompt_)); + stopWords.push_back(llama_utils::rtrim(si.user_prompt)); data["stop"] = stopWords; bool is_streamed = data["stream"]; // Enable full message debugging #ifdef DEBUG - LOG_INFO << "Request " << request_id << ": " - << "Current completion text"; + LOG_INFO << "Request " << request_id << ": " << "Current completion text"; LOG_INFO << "Request " << request_id << ": " << formatted_output; #endif if (is_streamed) { LOG_INFO << "Request " << request_id << ": " << "Streamed, waiting for respone"; - auto state = CreateInferenceState(llama_); + auto state = CreateInferenceState(si.ctx); // Queued task - queue_->runTaskInQueue([cb = std::move(callback), state, data, - request_id]() { + si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id]() { state->task_id = state->llama.RequestCompletion(data, false, false, -1); while (state->llama.model_loaded_external) { TaskResult result = state->llama.NextResult(state->task_id); @@ -528,7 +545,7 @@ void LlamaEngine::HandleInferenceImpl( const std::string str = "data: " + CreateReturnJson(llama_utils::generate_random_string(20), "_", - to_send) + + to_send) + "\n\n"; Json::Value respData; respData["data"] = str; @@ -540,14 +557,13 @@ void LlamaEngine::HandleInferenceImpl( cb(std::move(status), std::move(respData)); if (result.stop) { - LOG_INFO << "Request " << request_id << ": " - << "End of result"; + LOG_INFO << "Request " << request_id << ": " << "End of result"; state->llama.RequestCancel(state->task_id); Json::Value respData; const std::string str = "data: " + CreateReturnJson(llama_utils::generate_random_string(20), "_", - "", "stop") + + "", "stop") + "\n\n" + "data: [DONE]" + "\n\n"; respData["data"] = str; Json::Value status; @@ -588,35 +604,28 @@ void LlamaEngine::HandleInferenceImpl( status["status_code"] = k200OK; cb(std::move(status), std::move(respData)); } - LOG_INFO << "Request " << request_id << ": " - << "Inference completed"; + LOG_INFO << "Request " << request_id << ": " << "Inference completed"; }); } else { - queue_->runTaskInQueue([this, request_id, cb = std::move(callback), - d = std::move(data)]() { + auto state = CreateInferenceState(si.ctx); + si.q->runTaskInQueue([this, request_id, state, cb = std::move(callback), + d = std::move(data)]() { Json::Value respData; - int task_id = llama_.RequestCompletion(d, false, false, -1); + int task_id = state->llama.RequestCompletion(d, false, false, -1); 
LOG_INFO << "Request " << request_id << ": " << "Non stream, waiting for respone"; if (!json_value(d, "stream", false)) { bool has_error = false; std::string completion_text; - TaskResult result = llama_.NextResult(task_id); + TaskResult result = state->llama.NextResult(task_id); if (!result.error && result.stop) { int prompt_tokens = result.result_json["tokens_evaluated"]; int predicted_tokens = result.result_json["tokens_predicted"]; std::string to_send = result.result_json["content"]; llama_utils::ltrim(to_send); - //https://platform.openai.com/docs/api-reference/chat/object - // finish_reason string - // The reason the model stopped generating tokens. This will be `stop` - // if the model hit a natural stop point or a provided stop sequence, - // `length` if the maximum number of tokens specified in the request was reached, - // `content_filter` if content was omitted due to a flag from our content filters, - // `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. respData = CreateFullReturnJson( llama_utils::generate_random_string(20), "_", to_send, "_", - prompt_tokens, predicted_tokens, "stop" /*finish_reason*/); + prompt_tokens, predicted_tokens); } else { bool has_error = true; respData["message"] = "Internal error during inference"; @@ -630,8 +639,7 @@ void LlamaEngine::HandleInferenceImpl( status["status_code"] = k200OK; cb(std::move(status), std::move(respData)); - LOG_INFO << "Request " << request_id << ": " - << "Inference completed"; + LOG_INFO << "Request " << request_id << ": " << "Inference completed"; } }); } @@ -640,32 +648,35 @@ void LlamaEngine::HandleInferenceImpl( void LlamaEngine::HandleEmbeddingImpl( std::shared_ptr jsonBody, std::function&& callback) { + auto model_id = llama_utils::GetModelId(*jsonBody); + assert(server_map_.find(model_id) != server_map_.end()); int request_id = ++no_of_requests_; - LOG_INFO << "Request " << request_id << ": " + LOG_INFO << "Request " << request_id << ", " << "model " << model_id << ": " << "Generating reponse for embedding request"; // Queue embedding task - auto state = CreateInferenceState(llama_); + auto state = CreateInferenceState(server_map_[model_id].ctx); - queue_->runTaskInQueue([this, state, jsonBody, callback, request_id]() { + server_map_[model_id].q->runTaskInQueue([this, state, jsonBody, callback, + request_id]() { Json::Value responseData(Json::arrayValue); if (jsonBody->isMember("input")) { const Json::Value& input = (*jsonBody)["input"]; if (input.isString()) { // Process the single string input - state->task_id = llama_.RequestCompletion( + state->task_id = state->llama.RequestCompletion( {{"prompt", input.asString()}, {"n_predict", 0}}, false, true, -1); - TaskResult result = llama_.NextResult(state->task_id); + TaskResult result = state->llama.NextResult(state->task_id); std::vector embedding_result = result.result_json["embedding"]; responseData.append(CreateEmbeddingPayload(embedding_result, 0)); } else if (input.isArray()) { // Process each element in the array input for (const auto& elem : input) { if (elem.isString()) { - const int task_id = llama_.RequestCompletion( + const int task_id = state->llama.RequestCompletion( {{"prompt", elem.asString()}, {"n_predict", 0}}, false, true, -1); - TaskResult result = llama_.NextResult(task_id); + TaskResult result = state->llama.NextResult(task_id); std::vector embedding_result = result.result_json["embedding"]; responseData.append(CreateEmbeddingPayload(embedding_result, 0)); @@ -689,15 +700,18 @@ void 
LlamaEngine::HandleEmbeddingImpl( status["status_code"] = k200OK; callback(std::move(status), std::move(root)); - LOG_INFO << "Request " << request_id << ": " - << "Embedding completed"; + LOG_INFO << "Request " << request_id << ": " << "Embedding completed"; }); } bool LlamaEngine::CheckModelLoaded( - std::function& callback) { - if (!llama_.model_loaded_external) { - LOG_ERROR << "Model has not been loaded"; + std::function& callback, + const std::string& model_id) { + if (auto si = server_map_.find(model_id); + si == server_map_.end() || !si->second.ctx.model_loaded_external) { + LOG_ERROR << "Error: model_id: " << model_id + << ", existed: " << (si != server_map_.end()) + << ", loaded: " << (si == server_map_.end()); Json::Value jsonResp; jsonResp["message"] = "Model has not been loaded, please load model into nitro"; @@ -712,42 +726,33 @@ bool LlamaEngine::CheckModelLoaded( return true; } -void LlamaEngine::WarmUpModel() { - json pseudo; - - LOG_INFO << "Warm-up model"; - pseudo["prompt"] = "Hello"; - pseudo["n_predict"] = 2; - pseudo["stream"] = false; - const int task_id = llama_.RequestCompletion(pseudo, false, false, -1); - std::string completion_text; - TaskResult result = llama_.NextResult(task_id); - if (!result.error && result.stop) { - LOG_INFO << result.result_json.dump(-1, ' ', false, - json::error_handler_t::replace); - } -} - -void LlamaEngine::HandleBackgroundTask() { - while (llama_.model_loaded_external) { - // model_loaded = - llama_.UpdateSlots(); +void LlamaEngine::WarmUpModel(const std::string& model_id) { + if (auto si = server_map_.find(model_id); si != server_map_.end()) { + json pseudo; + + LOG_INFO << "Warm-up model: " << model_id; + pseudo["prompt"] = "Hello"; + pseudo["n_predict"] = 2; + pseudo["stream"] = false; + const int task_id = + si->second.ctx.RequestCompletion(pseudo, false, false, -1); + TaskResult result = si->second.ctx.NextResult(task_id); + if (!result.error && result.stop) { + LOG_INFO << result.result_json.dump(-1, ' ', false, + json::error_handler_t::replace); + } + } else { + LOG_WARN << "Model not found " << model_id; } - LOG_INFO << "Background task stopped! "; - llama_.KvCacheClear(); - LOG_INFO << "KV cache cleared!"; } -void LlamaEngine::StopBackgroundTask() { - if (llama_.model_loaded_external) { - llama_.model_loaded_external = false; - llama_.condition_tasks.notify_one(); - LOG_INFO << "Stopping background task! "; - if (bgr_thread_.joinable()) { - bgr_thread_.join(); - } - LOG_INFO << "Background task stopped! 
"; +bool LlamaEngine::ShouldInitBackend() const { + // May have race condition here, need to check + for (auto& [_, l] : server_map_) { + if (l.ctx.model_loaded_external) + return false; } + return true; } extern "C" { diff --git a/src/llama_engine.h b/src/llama_engine.h index e878530..acb61c8 100644 --- a/src/llama_engine.h +++ b/src/llama_engine.h @@ -1,13 +1,13 @@ #pragma once +#include "chat_completion_request.h" #include "cortex-common/enginei.h" #include "llama_server_context.h" #include "trantor/utils/ConcurrentTaskQueue.h" -#include "chat_completion_request.h" class LlamaEngine : public EngineI { public: - LlamaEngine(); - ~LlamaEngine() final; + LlamaEngine(); + ~LlamaEngine() final; // #### Interface #### void HandleChatCompletion( std::shared_ptr jsonBody, @@ -34,24 +34,27 @@ class LlamaEngine : public EngineI { std::shared_ptr jsonBody, std::function&& callback); bool CheckModelLoaded( - std::function& callback); - void WarmUpModel(); - void HandleBackgroundTask(); - void StopBackgroundTask(); + std::function& callback, + const std::string& model_id); + void WarmUpModel(const std::string& model_id); + bool ShouldInitBackend() const; private: - LlamaServerContext llama_; - std::unique_ptr queue_; - std::thread bgr_thread_; + struct ServerInfo { + LlamaServerContext ctx; + std::unique_ptr q; + std::string user_prompt; + std::string ai_prompt; + std::string system_prompt; + std::string pre_prompt; + int repeat_last_n; + bool caching_enabled; + std::string grammar_file_content; + }; + + // key: model_id, value: ServerInfo + std::unordered_map server_map_; - std::string user_prompt_; - std::string ai_prompt_; - std::string system_prompt_; - std::string pre_prompt_; - int repeat_last_n_; - bool caching_enabled_; std::atomic no_of_requests_ = 0; std::atomic no_of_chats_ = 0; - int clean_cache_threshold_; - std::string grammar_file_content_; }; \ No newline at end of file diff --git a/src/llama_server_context.cc b/src/llama_server_context.cc index 34a0fb1..ba12da8 100644 --- a/src/llama_server_context.cc +++ b/src/llama_server_context.cc @@ -230,6 +230,11 @@ void LlamaServerContext::Initialize() { // empty system prompt system_prompt = ""; system_tokens.clear(); + + model_loaded_external = true; + LOG_INFO << "Started background task here!"; + bgr_thread = + std::thread(std::bind(&LlamaServerContext::DoBackgroundTasks, this)); } void LlamaServerContext::KvCacheClear() { @@ -307,1313 +312,1338 @@ void LlamaServerContext::RequestCancel(int task_id) { condition_tasks.notify_one(); } -bool LlamaServerContext::UpdateSlots() { - // attend tasks - ProcessTasks(); - - // update the system prompt wait until all slots are idle state - if (system_need_update && all_slots_are_idle) { - LOG_DEBUG << "updating system prompt"; - UpdateSystemPrompt(); - } - - llama_batch_clear(batch); +void LlamaServerContext::ReleaseResources() { + if (model_loaded_external) { + LOG_INFO << "Releasing llama_server_context resources"; + model_loaded_external = false; + condition_tasks.notify_one(); - if (all_slots_are_idle) { - if (system_prompt.empty() && clean_kv_cache) { - LOG_DEBUG - << "all slots are idle and system prompt is empty, clear the KV " - "cache"; - KvCacheClear(); + if (bgr_thread.joinable()) { + bgr_thread.join(); } - // std::this_thread::sleep_for(std::chrono::milliseconds(5)); - // TODO: Need to implement queueing using CV for better performance - std::unique_lock lock(mutex_tasks); - condition_tasks.wait(lock, [&] { - return (!queue_tasks.empty() && model_loaded_external) || - (queue_tasks.empty() 
&& !model_loaded_external); - }); - } - for (LlamaClientSlot& slot : slots) { - if (slot.IsProcessing() && - (int)system_tokens.size() + slot.n_past >= slot.n_ctx) { - // Shift context - const int n_left = slot.n_past - slot.params.n_keep - 1; - const int n_discard = n_left / 2; + llama_free(ctx); + llama_free_model(model); + ctx = nullptr; + model = nullptr; + LOG_INFO << "Released llama_server_context resources"; + } +} - LOG_DEBUG << "slot " << slot.id - << " context shift - n_keep = " << slot.params.n_keep - << ", n_left = " << n_left << ", n_discard: " << n_discard - << ", n_ctx = " << n_ctx << ", n_past = " << slot.n_past - << ", n_system_tokens = " << system_tokens.size() - << ", n_cache_tokens = " << slot.cache_tokens.size(); +std::vector LlamaServerContext::Tokenize(const json& json_prompt, + bool add_bos) const { + // TODO: currently, we tokenize using special tokens by default + // this is not always correct (see + // https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) + // but it's better compared to completely ignoring ChatML and other + // chat templates + const bool TMP_FORCE_SPECIAL = true; - llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, - slot.params.n_keep + n_discard + 1); - llama_kv_cache_seq_add(ctx, slot.id, slot.params.n_keep + 1 + n_discard, - slot.n_past, -n_discard); + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + std::vector prompt_tokens; - if (slot.params.cache_prompt) { - for (size_t i = slot.params.n_keep + 1 + n_discard; - i < slot.cache_tokens.size(); i++) { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + if (json_prompt.is_array()) { + bool first = true; + for (const auto& p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + std::vector p; + if (first) { + p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); + first = false; + } else { + p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); } - - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + prompt_tokens.push_back(p.template get()); } - - slot.n_past -= n_discard; - - slot.truncated = true; } + } else { + auto s = json_prompt.template get(); + prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); } - // decode any currently ongoing sequences - for (auto& slot : slots) { - // release the slot - if (slot.command == SlotCommand::kRelease) { - slot.state = SlotState::kIdle; - slot.command = SlotCommand::kNone; - slot.t_last_used = ggml_time_us(); + return prompt_tokens; +} - LOG_INFO << "slot released: " - << "id_slot: " << slot.id << ", id_task: " << slot.task_id - << ", n_ctx: " << n_ctx << ", n_past: " << slot.n_past - << ", n_system_tokens: " << system_tokens.size() - << ", n_cache_tokens: " << slot.cache_tokens.size() - << ", truncated: " << slot.truncated; +LlamaClientSlot* LlamaServerContext::GetSlot(int id) { + int64_t t_last = ggml_time_us(); + LlamaClientSlot* last_used = nullptr; - continue; + for (LlamaClientSlot& slot : slots) { + if (slot.id == id && slot.Available()) { + return &slot; } - if (slot.state == SlotState::kIdle) { - continue; + if (slot.Available() && slot.t_last_used < t_last) { + last_used = &slot; + t_last = slot.t_last_used; } + } - slot.i_batch = batch.n_tokens; + return last_used; +} - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, - 
{slot.id}, true); +bool LlamaServerContext::LaunchSlotWithData(LlamaClientSlot*& slot, json data) { + SlotParams default_params; + llama_sampling_params default_sparams; - slot.n_decoded += 1; - slot.n_past += 1; + if (data.count("__oaicompat") != 0) { + slot->oaicompat = true; + slot->oaicompat_model = + json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + } else { + slot->oaicompat = false; + slot->oaicompat_model = ""; + } - if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(slot.sampled); - } + slot->params.stream = json_value(data, "stream", false); + slot->params.cache_prompt = json_value(data, "cache_prompt", false); + slot->params.n_predict = + json_value(data, "n_predict", default_params.n_predict); + slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); + slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot->sparams.typical_p = + json_value(data, "typical_p", default_sparams.typical_p); + slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot->sparams.penalty_last_n = + json_value(data, "repeat_last_n", default_sparams.penalty_last_n); + slot->sparams.penalty_repeat = + json_value(data, "repeat_penalty", default_sparams.penalty_repeat); + slot->sparams.penalty_freq = + json_value(data, "frequency_penalty", default_sparams.penalty_freq); + slot->sparams.penalty_present = + json_value(data, "presence_penalty", default_sparams.penalty_present); + slot->sparams.mirostat = + json_value(data, "mirostat", default_sparams.mirostat); + slot->sparams.mirostat_tau = + json_value(data, "mirostat_tau", default_sparams.mirostat_tau); + slot->sparams.mirostat_eta = + json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot->sparams.penalize_nl = + json_value(data, "penalize_nl", default_sparams.penalize_nl); + slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); + slot->params.seed = json_value(data, "seed", default_params.seed); + slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); + slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - LOG_TRACE << "slot decode token - " - << " id_slot: " << slot.id << ", task_id: " << slot.task_id - << ", n_ctx: " << n_ctx << ", n_past: " << slot.n_past - << ", n_system_tokens: " << system_tokens.size() - << ", n_cache_tokens: " << slot.cache_tokens.size() - << ", truncated: " << slot.truncated; + // infill + if (data.count("input_prefix") != 0) { + slot->params.input_prefix = data["input_prefix"]; + } else { + slot->params.input_prefix = ""; } - // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); + if (data.count("input_suffix") != 0) { + slot->params.input_suffix = data["input_suffix"]; + } else { + slot->params.input_suffix = ""; + } - // assign workload to the slots - if (params.cont_batching || batch.n_tokens == 0) { - for (auto& slot : slots) { - const bool has_prompt = slot.prompt.is_array() || - (slot.prompt.is_string() && - !slot.prompt.get().empty()) || - !slot.images.empty(); + if (data.count("prompt") != 0) { + slot->prompt = data["prompt"]; + } else { + slot->prompt = ""; + } - // empty prompt passed -> release the slot and send empty response - if (slot.state == SlotState::kIdle && - slot.command == SlotCommand::kLoadPrompt && !has_prompt) { - 
slot.Release(); - slot.PrintTimings(); - SendFinalResponse(slot); - continue; + slot->sparams.penalty_prompt_tokens.clear(); + slot->sparams.use_penalty_prompt_tokens = false; + const auto& penalty_prompt = data.find("penalty_prompt"); + if (penalty_prompt != data.end()) { + if (penalty_prompt->is_string()) { + const auto penalty_prompt_string = penalty_prompt->get(); + auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false); + slot->sparams.penalty_prompt_tokens.swap(penalty_tokens); + if (slot->params.n_predict > 0) { + slot->sparams.penalty_prompt_tokens.reserve( + slot->sparams.penalty_prompt_tokens.size() + + slot->params.n_predict); } + slot->sparams.use_penalty_prompt_tokens = true; + } else if (penalty_prompt->is_array()) { + const auto n_tokens = penalty_prompt->size(); + slot->sparams.penalty_prompt_tokens.reserve( + n_tokens + std::max(0, slot->params.n_predict)); + const int n_vocab = llama_n_vocab(model); + for (const auto& penalty_token : *penalty_prompt) { + if (penalty_token.is_number_integer()) { + const auto tok = penalty_token.get(); + if (tok >= 0 && tok < n_vocab) { + slot->sparams.penalty_prompt_tokens.push_back(tok); + } + } + } + slot->sparams.use_penalty_prompt_tokens = true; + } + } - // need process the prompt - if (slot.state == SlotState::kIdle && - slot.command == SlotCommand::kLoadPrompt) { - auto& prompt_tokens = slot.prompt_tokens; - - // we haven't tokenized the prompt yet - do it now: - if (prompt_tokens.empty()) { - slot.t_start_process_prompt = ggml_time_us(); - slot.t_start_genereration = 0; + slot->sparams.logit_bias.clear(); - if (slot.infill) { - bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(' ') == 0 && - params.input_suffix.size() > 1) { - params.input_suffix.erase(0, 1); - suff_rm_leading_spc = false; - } - auto prefix_tokens = Tokenize(slot.params.input_prefix, false); - auto suffix_tokens = Tokenize(slot.params.input_suffix, false); + if (json_value(data, "ignore_eos", false)) { + slot->sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + } - const int space_token = - 29871; // TODO: this should not be hardcoded - if (suff_rm_leading_spc && !suffix_tokens.empty() && - suffix_tokens[0] == space_token) { - suffix_tokens.erase(suffix_tokens.begin()); - } - - prefix_tokens.insert(prefix_tokens.begin(), - llama_token_prefix(model)); - prefix_tokens.insert(prefix_tokens.begin(), - llama_token_bos(model)); // always add BOS - prefix_tokens.insert(prefix_tokens.end(), - llama_token_suffix(model)); - prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), - suffix_tokens.end()); - prefix_tokens.push_back(llama_token_middle(model)); - prompt_tokens = prefix_tokens; - } else { - prompt_tokens = Tokenize( - slot.prompt, - system_prompt.empty() && - add_bos_token); // add BOS if there isn't system prompt + const auto& logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_n_vocab(model); + for (const auto& el : *logit_bias) { + if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + if (el[1].is_number()) { + slot->sparams.logit_bias[tok] = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + slot->sparams.logit_bias[tok] = -INFINITY; } + } + } + } + } - slot.n_past = 0; - slot.num_prompt_tokens = prompt_tokens.size(); + slot->params.antiprompt.clear(); - LOG_VERBOSE( - "prompt tokenized", - { - {"id_slot", slot.id}, - {"id_task", 
slot.task_id}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_prompt_tokens", slot.num_prompt_tokens}, - {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), - prompt_tokens.cend())}, - }); + const auto& stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto& word : *stop) { + if (!word.empty()) { + slot->params.antiprompt.push_back(word); + } + } + } - if (slot.embedding) { - // this prompt is too large to process - discard it - if (slot.num_prompt_tokens > n_ubatch) { - LOG_DEBUG << "This prompt is too large to process: " - "num_promt_tokens = " - << slot.num_prompt_tokens - << ", n_ubatch = " << n_ubatch; - slot.state = SlotState::kProcessing; - slot.command = SlotCommand::kNone; - slot.Release(); - slot.PrintTimings(); - SendFinalResponse(slot); - continue; - } - } else { - if (slot.params.n_keep < 0) { - slot.params.n_keep = slot.num_prompt_tokens; + if (multimodal) { + const auto& images_data = data.find("image_data"); + if (images_data != data.end() && images_data->is_array()) { + for (const auto& img : *images_data) { + const std::vector image_buffer = + base64_decode(img["data"].get()); + + SlotImage img_sl; + img_sl.id = + img.count("id") != 0 ? img["id"].get() : slot->images.size(); + img_sl.img_data = clip_image_u8_init(); + if (!clip_image_load_from_bytes(image_buffer.data(), + image_buffer.size(), img_sl.img_data)) { + LOG_DEBUG << "slot " << slot->id + << " - failed to load image [id: " << img_sl.id << "]"; + return false; + } + LOG_DEBUG << "slot " << slot->id << " - loaded image"; + img_sl.request_encode_image = true; + slot->images.push_back(img_sl); + } + // process prompt + // example: system prompt [img-102] user [img-103] describe [img-134] -> + // [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, + // {id: 134, prefix: ' describe '}]} + if (slot->images.size() > 0 && !slot->prompt.is_array()) { + std::string prompt = slot->prompt.get(); + size_t pos = 0, begin_prefix = 0; + std::string pattern = "[img-"; + while ((pos = prompt.find(pattern, pos)) != std::string::npos) { + size_t end_prefix = pos; + pos += pattern.length(); + size_t end_pos = prompt.find("]", pos); + if (end_pos != std::string::npos) { + std::string image_id = prompt.substr(pos, end_pos - pos); + try { + int img_id = std::stoi(image_id); + bool found = false; + for (SlotImage& img : slot->images) { + if (img.id == img_id) { + found = true; + img.prefix_prompt = + prompt.substr(begin_prefix, end_prefix - begin_prefix); + begin_prefix = end_pos + 1; + break; + } + } + if (!found) { + LOG_WARN << "ERROR: Image with id: " << img_id + << ", not found.\n"; + slot->images.clear(); + return false; + } + } catch (const std::invalid_argument& e) { + LOG_WARN << "Invalid image number id in prompt: " << e.what(); + slot->images.clear(); + return false; } - slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); + } + } + slot->prompt = ""; + slot->params.input_suffix = prompt.substr(begin_prefix); + slot->params.cache_prompt = + false; // multimodal doesn't support cache prompt + } + } + } - // if input prompt is too big, truncate it - if (slot.num_prompt_tokens >= slot.n_ctx) { - const int n_left = slot.n_ctx - slot.params.n_keep; - const int n_block_size = n_left / 2; - const int erased_blocks = - (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / - n_block_size; + if (slot->ctx_sampling != nullptr) { + llama_sampling_free(slot->ctx_sampling); + } + slot->ctx_sampling = llama_sampling_init(slot->sparams); + 
llama_set_rng_seed(ctx, slot->params.seed); + slot->command = SlotCommand::kLoadPrompt; + slot->prompt_tokens.clear(); - std::vector new_tokens( - prompt_tokens.begin(), - prompt_tokens.begin() + slot.params.n_keep); - new_tokens.insert(new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + - erased_blocks * n_block_size, - prompt_tokens.end()); + all_slots_are_idle = false; - LOG_VERBOSE("input truncated", - { - {"id_slot", slot.id}, - {"id_task", slot.task_id}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"n_prompt_tokens", slot.num_prompt_tokens}, - {"prompt_tokens", - tokens_to_str(ctx, prompt_tokens.cbegin(), - prompt_tokens.cend())}, - }); + LOG_DEBUG << "slot " << slot->id + << " is processing [task id: " << slot->task_id << "]"; - slot.truncated = true; - prompt_tokens = new_tokens; + return true; +} - slot.num_prompt_tokens = prompt_tokens.size(); - GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx); - } +void LlamaServerContext::UpdateSystemPrompt() { + system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token); - llama_sampling_reset(slot.ctx_sampling); + llama_batch_clear(batch); - if (!slot.params.cache_prompt) { - slot.n_past = 0; - slot.num_prompt_tokens_processed = slot.num_prompt_tokens; - } else { - // push the prompt into the sampling context (do not apply grammar) - for (auto& token : prompt_tokens) { - llama_sampling_accept(slot.ctx_sampling, ctx, token, false); - } + KvCacheClear(); - slot.n_past = common_part(slot.cache_tokens, prompt_tokens); - slot.num_prompt_tokens_processed = - slot.num_prompt_tokens - slot.n_past; + for (int i = 0; i < (int)system_tokens.size(); ++i) { + llama_batch_add(batch, system_tokens[i], i, {0}, false); + } - LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", - slot.id, slot.n_past, slot.num_prompt_tokens_processed); - } - } + if (llama_decode(ctx, batch) != 0) { + LOG_WARN << __func__ << ": llama_decode() failed"; + return; + } - if (slot.n_past == slot.num_prompt_tokens) { - // we have to evaluate at least 1 token to generate logits. 
- LOG_DEBUG << "slot " << slot.id - << " : we have to evaluate at least 1 token to " - "generate logits"; - slot.n_past--; - } + // assign the system KV cache to all parallel sequences + for (int32_t i = 1; i < params.n_parallel; ++i) { + llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size()); + } - slot.num_prompt_tokens_processed = 0; - } + LOG_DEBUG << "system prompt updated"; + system_need_update = false; +} - if (slot.embedding) { - // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.num_prompt_tokens > n_batch) { - continue; - } - } +void LlamaServerContext::NotifySystemPromptChanged() { + // release all slots + for (LlamaClientSlot& slot : slots) { + slot.Release(); + } - LOG_VERBOSE( - "prompt ingested", - { - {"n_past", slot.n_past}, - {"cached", - tokens_to_str(ctx, slot.cache_tokens.cbegin(), - slot.cache_tokens.cbegin() + slot.n_past)}, - {"to_eval", - tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, - slot.cache_tokens.cend())}, - }); + system_need_update = true; +} - // keep only the common part - int p0 = (int)system_tokens.size() + slot.n_past; - if (!llama_kv_cache_seq_rm(ctx, slot.id, p0, -1)) { - // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); +void LlamaServerContext::ProcessSystemPromptData(const json& sys_props) { + system_prompt = sys_props.value("prompt", ""); + name_user = sys_props.value("anti_prompt", ""); + name_assistant = sys_props.value("assistant_name", ""); - p0 = (int)system_tokens.size(); - if (p0 != 0) { - // copy over the system prompt when there is one - llama_kv_cache_seq_cp(ctx, 0, slot.id, -1, -1); - } + if (slots.size() > 0) { + NotifySystemPromptChanged(); + } +} - // there is no common part left (except for the system prompt) - slot.n_past = 0; - // TODO: is the system prompt ever in the sampling context? - llama_sampling_reset(slot.ctx_sampling); - } +size_t LlamaServerContext::FindStoppingStrings(const std::string& text, + const size_t last_token_size, + const StopType type, + LlamaClientSlot& slot) { + size_t stop_pos = std::string::npos; - // remove the non-common part from the cache - slot.cache_tokens.resize(slot.n_past); - LOG_INFO << "kv cache rm [p0, end) - " - << " id_slot: " << slot.id << ", task_id: " << slot.task_id - << ", p0: " << p0; + for (const std::string& word : slot.params.antiprompt) { + size_t pos; + if (type == StopType::kStopFull) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + pos = text.find(word, from_pos); + } else { + pos = find_partial_stop_string(word, text); + } + if (pos != std::string::npos && + (stop_pos == std::string::npos || pos < stop_pos)) { + if (type == StopType::kStopFull) { + slot.stopped_word = true; + slot.stopping_word = word; + slot.has_next_token = false; + } + stop_pos = pos; + } + } - const bool has_images = ProcessImages(slot); + return stop_pos; +} - // process the prefix of first image - std::vector prefix_tokens = - has_images ? 
Tokenize(slot.images[0].prefix_prompt, add_bos_token) - : prompt_tokens; - for (; slot.n_past < (int)prefix_tokens.size(); ++slot.n_past) { - llama_batch_add(batch, prefix_tokens[slot.n_past], - system_tokens.size() + slot.n_past, {slot.id}, false); - if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); - } - slot.num_prompt_tokens_processed++; - } +bool LlamaServerContext::ProcessToken(CompletionTokenOutput& result, + LlamaClientSlot& slot) { + // remember which tokens were sampled - used for repetition penalties during + // sampling + const std::string token_str = llama_token_to_piece(ctx, result.tok); + slot.sampled = result.tok; + + // search stop word and delete it + slot.generated_text += token_str; + slot.has_next_token = true; + + if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) { + // we can change penalty_prompt_tokens because it is always created from + // scratch each request + slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); + } + + // check if there is incomplete UTF-8 character at the end + bool incomplete = false; + for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) { + unsigned char c = slot.generated_text[slot.generated_text.size() - i]; + if ((c & 0xC0) == 0x80) { + // continuation byte: 10xxxxxx + continue; + } + if ((c & 0xE0) == 0xC0) { + // 2-byte character: 110xxxxx ... + incomplete = i < 2; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character: 1110xxxx ... + incomplete = i < 3; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character: 11110xxx ... + incomplete = i < 4; + } + // else 1-byte character or invalid byte + break; + } + + if (!incomplete) { + size_t pos = std::min(slot.sent_count, slot.generated_text.size()); + const std::string str_test = slot.generated_text.substr(pos); + bool is_stop_full = false; + size_t stop_pos = FindStoppingStrings(str_test, token_str.size(), + StopType::kStopFull, slot); + if (stop_pos != std::string::npos) { + is_stop_full = true; + slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); + pos = std::min(slot.sent_count, slot.generated_text.size()); + } else { + is_stop_full = false; + stop_pos = FindStoppingStrings(str_test, token_str.size(), + StopType::kStopPartial, slot); + } + + // check if there is any token to predict + if (stop_pos == std::string::npos || + (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { + // no send the stop word in the response + result.text_to_send = slot.generated_text.substr(pos, std::string::npos); + slot.sent_count += result.text_to_send.size(); + // add the token to slot queue and cache + } + slot.AddTokenString(result); + if (slot.params.stream) { + SendPartialResponse(slot, result); + } + } + + if (incomplete) { + slot.has_next_token = true; + } + + // check the limits + if (slot.n_decoded > 2 && slot.has_next_token && !slot.HasBudget(params)) { + slot.stopped_limit = true; + slot.has_next_token = false; + } + + if (result.tok == llama_token_eos(model)) { + slot.stopped_eos = true; + slot.has_next_token = false; + LOG_VERBOSE("eos token found", {}); + } + + LOG_VERBOSE( + "next token", + { + {"token", result.tok}, + {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, + {"has_next_token", slot.has_next_token}, + {"n_remain", slot.n_remaining}, + {"num_tokens_predicted", slot.n_decoded}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + 
{"stopping_word", slot.stopping_word}, + }); + + return slot.has_next_token; // continue +} +bool LlamaServerContext::ProcessImages(LlamaClientSlot& slot) const { + for (SlotImage& img : slot.images) { + if (!img.request_encode_image) { + continue; + } - LOG_VERBOSE("prompt processing progress", - { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - {"progress", (float)slot.num_prompt_tokens_processed / - slot.num_prompt_tokens}, - }); + if (!llava_image_embed_make_with_clip_img( + clp_ctx, params.n_threads, img.img_data, &img.image_embedding, + &img.image_tokens)) { + LOG_WARN << "Error processing the given image"; + return false; + } - if (has_images && !IngestImages(slot, n_batch)) { - LOG_WARN << "failed processing images"; - return false; - } + img.request_encode_image = false; + } - // entire prompt has been processed - start decoding new tokens - if (slot.n_past == slot.num_prompt_tokens) { - slot.state = SlotState::kProcessing; - slot.command = SlotCommand::kNone; + return slot.images.size() > 0; +} +void LlamaServerContext::SendError(TaskServer& task, std::string error) { + std::unique_lock lock(mutex_results); + TaskResult res; + res.id = task.id; + res.multitask_id = task.multitask_id; + res.stop = false; + res.error = true; + res.result_json = {{"content", error}}; + queue_results.push_back(res); + condition_results.notify_all(); +} - GGML_ASSERT(batch.n_tokens > 0); +void LlamaServerContext::AddMultiTask(int id, std::vector& sub_ids) { + std::lock_guard lock(mutex_tasks); + TaskMulti multi; + multi.id = id; + std::copy( + sub_ids.begin(), sub_ids.end(), + std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); + queue_multitasks.push_back(multi); + condition_tasks.notify_one(); +} - // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; +void LlamaServerContext::UpdateMultiTask(int multitask_id, int subtask_id, + TaskResult& result) { + std::lock_guard lock(mutex_tasks); + for (auto& multitask : queue_multitasks) { + if (multitask.id == multitask_id) { + multitask.subtasks_remaining.erase(subtask_id); + multitask.results.push_back(result); + condition_tasks.notify_one(); + } + } +} - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; +json LlamaServerContext::GetFormatedGeneration(LlamaClientSlot& slot) { + const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); + const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && + eos_bias->second < 0.0f && + std::isinf(eos_bias->second); + return json{ + {"n_ctx", slot.n_ctx}, + {"model", params.model_alias}, + {"seed", slot.params.seed}, + {"temperature", slot.sparams.temp}, + {"top_k", slot.sparams.top_k}, + {"top_p", slot.sparams.top_p}, + {"min_p", slot.sparams.min_p}, + {"tfs_z", slot.sparams.tfs_z}, + {"typical_p", slot.sparams.typical_p}, + {"repeat_last_n", slot.sparams.penalty_last_n}, + {"repeat_penalty", slot.sparams.penalty_repeat}, + {"presence_penalty", slot.sparams.penalty_present}, + {"frequency_penalty", slot.sparams.penalty_freq}, + {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, + {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, + {"mirostat", slot.sparams.mirostat}, + {"mirostat_tau", slot.sparams.mirostat_tau}, + {"mirostat_eta", slot.sparams.mirostat_eta}, + {"penalize_nl", slot.sparams.penalize_nl}, + {"stop", slot.params.antiprompt}, + {"n_predict", slot.params.n_predict}, + {"n_keep", params.n_keep}, + {"ignore_eos", ignore_eos}, + 
{"stream", slot.params.stream}, + {"logit_bias", slot.sparams.logit_bias}, + {"n_probs", slot.sparams.n_probs}, + {"grammar", slot.sparams.grammar}, + }; +} - LOG_VERBOSE("prompt done", { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - }); - } - } +void LlamaServerContext::SendPartialResponse(LlamaClientSlot& slot, + CompletionTokenOutput tkn) { + std::unique_lock lock(mutex_results); + TaskResult res; + res.id = slot.task_id; + res.multitask_id = slot.multitask_id; + res.error = false; + res.stop = false; + + res.result_json = json{{"content", tkn.text_to_send}, + {"stop", false}, + {"slot_id", slot.id}, + {"multimodal", multimodal}}; + + if (slot.sparams.n_probs > 0) { + std::vector probs_output = {}; + const std::vector to_send_toks = + llama_tokenize(ctx, tkn.text_to_send, false); + size_t probs_pos = std::min(slot.sent_token_probs_index, + slot.generated_token_probs.size()); + size_t probs_stop_pos = + std::min(slot.sent_token_probs_index + to_send_toks.size(), + slot.generated_token_probs.size()); + if (probs_pos < probs_stop_pos) { + probs_output = std::vector( + slot.generated_token_probs.begin() + probs_pos, + slot.generated_token_probs.begin() + probs_stop_pos); } + slot.sent_token_probs_index = probs_stop_pos; + res.result_json["completion_probabilities"] = + probs_vector_to_json(ctx, probs_output); } - if (batch.n_tokens == 0) { - all_slots_are_idle = true; - return true; + if (slot.oaicompat) { + res.result_json["oaicompat_token_ctr"] = slot.n_decoded; + res.result_json["model"] = slot.oaicompat_model; } - for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, - 0, - 0, // unused - }; + queue_results.push_back(res); + condition_results.notify_all(); +} - const int ret = llama_decode(ctx, batch_view); - if (ret != 0) { - if (n_batch == 1 || ret < 0) { - // if you get here, it means the KV cache is full - try increasing it via the context size - LOG_ERROR << "Failed to decode the batch: KV cache is full - try " - "increasing it via the context size: " - << "i = " << i << ", n_batch = " << n_batch - << ", ret = " << ret; - for (auto& slot : slots) { - slot.state = SlotState::kProcessing; - slot.command = SlotCommand::kNone; - slot.Release(); - // SendError(slot, - // "Input prompt is too big compared to KV size. Please " - // "try increasing KV size."); - } - break; // break loop of n_batch - } +void LlamaServerContext::SendFinalResponse(LlamaClientSlot& slot) { + std::unique_lock lock(mutex_results); + TaskResult res; + res.id = slot.task_id; + res.multitask_id = slot.multitask_id; + res.error = false; + res.stop = true; - LOG_WARN << "Failed to find free space in the KV cache, retrying with " - "smaller n_batch = " - << n_batch / 2; + res.result_json = + json{{"content", !slot.params.stream ? 
slot.generated_text : ""}, + {"slot_id", slot.id}, + {"stop", true}, + {"model", params.model_alias}, + {"tokens_predicted", slot.n_decoded}, + {"tokens_evaluated", slot.num_prompt_tokens}, + {"generation_settings", GetFormatedGeneration(slot)}, + {"prompt", slot.prompt}, + {"truncated", slot.truncated}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + {"tokens_cached", slot.n_past}, + {"timings", slot.GetFormatedTimings()}}; - // retry with half the batch size to try to find a free slot in the KV - // cache - n_batch /= 2; - i -= n_batch; - continue; + if (slot.sparams.n_probs > 0) { + std::vector probs = {}; + if (!slot.params.stream && slot.stopped_word) { + const std::vector stop_word_toks = + llama_tokenize(ctx, slot.stopping_word, false); + probs = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - stop_word_toks.size()); + } else { + probs = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.begin() + slot.sent_token_probs_index); } + res.result_json["completion_probabilities"] = + probs_vector_to_json(ctx, probs); + } - for (auto& slot : slots) { - if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { - continue; - } + if (slot.oaicompat) { + res.result_json["oaicompat_token_ctr"] = slot.n_decoded; + res.result_json["model"] = slot.oaicompat_model; + } - // prompt evaluated for embedding - if (slot.embedding) { - SendEmbedding(slot); - slot.Release(); - slot.i_batch = -1; - return true; - } + // parent multitask, if any, needs to be updated + if (slot.multitask_id != -1) { + UpdateMultiTask(slot.multitask_id, slot.task_id, res); + } - CompletionTokenOutput result; - const llama_token id = - llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i); + queue_results.push_back(res); + condition_results.notify_all(); +} + +void LlamaServerContext::SendEmbedding(LlamaClientSlot& slot) { + std::unique_lock lock(mutex_results); + TaskResult res; + res.id = slot.task_id; + res.multitask_id = slot.multitask_id; + res.error = false; + res.stop = true; - llama_sampling_accept(slot.ctx_sampling, ctx, id, true); + const int n_embd = llama_n_embd(model); - if (slot.n_decoded == 1) { - slot.t_start_genereration = ggml_time_us(); - slot.t_prompt_processing = - (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3; - } + std::vector embd_res(n_embd, 0.0f); - llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(), - slot.ctx_sampling->cur.size(), false}; - result.tok = id; + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } - const int32_t n_probs = slot.sparams.n_probs; - if (slot.sparams.temp <= 0 && n_probs > 0) { - // for llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &cur_p); - } + const float* embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } - for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i) { - result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); - } + if (embd == NULL) { + LOG_ERROR << "failed to get embeddings" << " token " << batch.token[i] + << ", seq_id " << batch.seq_id[i][0]; - if (!ProcessToken(result, slot)) { - slot.Release(); - slot.PrintTimings(); - SendFinalResponse(slot); - } + res.result_json = json{ + {"embedding", std::vector(n_embd, 0.0f)}, + }; - 
slot.i_batch = -1; + continue; } + + llama_embd_normalize(embd, embd_res.data(), n_embd); } - return true; + res.result_json = json{ + {"embedding", embd_res}, + }; + + queue_results.push_back(res); + condition_results.notify_all(); } -std::vector LlamaServerContext::Tokenize(const json& json_prompt, - bool add_bos) const { - // TODO: currently, we tokenize using special tokens by default - // this is not always correct (see - // https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) - // but it's better compared to completely ignoring ChatML and other - // chat templates - const bool TMP_FORCE_SPECIAL = true; +// for multiple images processing +bool LlamaServerContext::IngestImages(LlamaClientSlot& slot, int n_batch) { + int image_idx = 0; - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - std::vector prompt_tokens; + while (image_idx < (int)slot.images.size()) { + SlotImage& img = slot.images[image_idx]; - if (json_prompt.is_array()) { - bool first = true; - for (const auto& p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); - std::vector p; - if (first) { - p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); - first = false; - } else { - p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); - } - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } - prompt_tokens.push_back(p.template get()); + // process prefix prompt + for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, + 0, + 0, // unused + }; + if (llama_decode(ctx, batch_view)) { + LOG_WARN << __func__ << " : failed to eval\n"; + return false; } } - } else { - auto s = json_prompt.template get(); - prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); - } - - return prompt_tokens; -} -LlamaClientSlot* LlamaServerContext::GetSlot(int id) { - int64_t t_last = ggml_time_us(); - LlamaClientSlot* last_used = nullptr; + // process image with llm + for (int i = 0; i < img.image_tokens; i += n_batch) { + int n_eval = img.image_tokens - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } - for (LlamaClientSlot& slot : slots) { - if (slot.id == id && slot.Available()) { - return &slot; + const int n_embd = llama_n_embd(model); + llama_batch batch_img = { + n_eval, nullptr, (img.image_embedding + i * n_embd), + nullptr, nullptr, nullptr, + nullptr, slot.n_past, 1, + 0, + }; + if (llama_decode(ctx, batch_img)) { + LOG_DEBUG << __func__ << " : failed to eval image"; + return false; + } + slot.n_past += n_eval; } + image_idx++; - if (slot.Available() && slot.t_last_used < t_last) { - last_used = &slot; - t_last = slot.t_last_used; + llama_batch_clear(batch); + + // append prefix of next image + const auto json_prompt = + (image_idx >= (int)slot.images.size()) + ? 
slot.params.input_suffix + : // no more images, then process suffix prompt + (json)(slot.images[image_idx].prefix_prompt); + + std::vector append_tokens = + Tokenize(json_prompt, false); // has next image + for (int i = 0; i < (int)append_tokens.size(); ++i) { + llama_batch_add(batch, append_tokens[i], slot.n_past, {slot.id}, true); + slot.n_past += 1; } } - return last_used; + return true; } -bool LlamaServerContext::LaunchSlotWithData(LlamaClientSlot*& slot, json data) { - SlotParams default_params; - llama_sampling_params default_sparams; - - if (data.count("__oaicompat") != 0) { - slot->oaicompat = true; - slot->oaicompat_model = - json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - } else { - slot->oaicompat = false; - slot->oaicompat_model = ""; - } - - slot->params.stream = json_value(data, "stream", false); - slot->params.cache_prompt = json_value(data, "cache_prompt", false); - slot->params.n_predict = - json_value(data, "n_predict", default_params.n_predict); - slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot->sparams.typical_p = - json_value(data, "typical_p", default_sparams.typical_p); - slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot->sparams.penalty_last_n = - json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot->sparams.penalty_repeat = - json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot->sparams.penalty_freq = - json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot->sparams.penalty_present = - json_value(data, "presence_penalty", default_sparams.penalty_present); - slot->sparams.mirostat = - json_value(data, "mirostat", default_sparams.mirostat); - slot->sparams.mirostat_tau = - json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot->sparams.mirostat_eta = - json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot->sparams.penalize_nl = - json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); - slot->params.seed = json_value(data, "seed", default_params.seed); - slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); +int LlamaServerContext::SplitMultipromptTask(TaskServer& multiprompt_task) { + int prompt_count = multiprompt_task.data.at("prompt").size(); + assert(prompt_count > 1); - // infill - if (data.count("input_prefix") != 0) { - slot->params.input_prefix = data["input_prefix"]; - } else { - slot->params.input_prefix = ""; - } + int multitask_id = id_gen++; + std::vector subtask_ids(prompt_count); + for (int i = 0; i < prompt_count; i++) { + json subtask_data = multiprompt_task.data; + subtask_data["prompt"] = subtask_data["prompt"][i]; - if (data.count("input_suffix") != 0) { - slot->params.input_suffix = data["input_suffix"]; - } else { - slot->params.input_suffix = ""; + // subtasks inherit everything else (infill mode, embedding mode, etc.) 
+ subtask_ids[i] = + RequestCompletion(subtask_data, multiprompt_task.infill_mode, + multiprompt_task.embedding_mode, multitask_id); } - if (data.count("prompt") != 0) { - slot->prompt = data["prompt"]; - } else { - slot->prompt = ""; - } + // queue up the multitask so we can track its subtask progression + AddMultiTask(multitask_id, subtask_ids); + return multitask_id; +} - slot->sparams.penalty_prompt_tokens.clear(); - slot->sparams.use_penalty_prompt_tokens = false; - const auto& penalty_prompt = data.find("penalty_prompt"); - if (penalty_prompt != data.end()) { - if (penalty_prompt->is_string()) { - const auto penalty_prompt_string = penalty_prompt->get(); - auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false); - slot->sparams.penalty_prompt_tokens.swap(penalty_tokens); - if (slot->params.n_predict > 0) { - slot->sparams.penalty_prompt_tokens.reserve( - slot->sparams.penalty_prompt_tokens.size() + - slot->params.n_predict); - } - slot->sparams.use_penalty_prompt_tokens = true; - } else if (penalty_prompt->is_array()) { - const auto n_tokens = penalty_prompt->size(); - slot->sparams.penalty_prompt_tokens.reserve( - n_tokens + std::max(0, slot->params.n_predict)); - const int n_vocab = llama_n_vocab(model); - for (const auto& penalty_token : *penalty_prompt) { - if (penalty_token.is_number_integer()) { - const auto tok = penalty_token.get(); - if (tok >= 0 && tok < n_vocab) { - slot->sparams.penalty_prompt_tokens.push_back(tok); - } +void LlamaServerContext::ProcessTasks() { + std::unique_lock lock(mutex_tasks); + while (!queue_tasks.empty()) { + TaskServer task = queue_tasks.front(); + queue_tasks.erase(queue_tasks.begin()); + switch (task.type) { + case TaskType::kCompletionTask: { + LlamaClientSlot* slot = GetSlot(json_value(task.data, "slot_id", -1)); + if (slot == nullptr) { + LOG_WARN << "slot unavailable"; + // send error result + SendError(task, "slot unavailable"); + return; } - } - slot->sparams.use_penalty_prompt_tokens = true; - } - } - slot->sparams.logit_bias.clear(); + if (task.data.contains("system_prompt")) { + ProcessSystemPromptData(task.data["system_prompt"]); + } - if (json_value(data, "ignore_eos", false)) { - slot->sparams.logit_bias[llama_token_eos(model)] = -INFINITY; - } + slot->Reset(); - const auto& logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_n_vocab(model); - for (const auto& el : *logit_bias) { - if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - if (el[1].is_number()) { - slot->sparams.logit_bias[tok] = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - slot->sparams.logit_bias[tok] = -INFINITY; + slot->infill = task.infill_mode; + slot->embedding = task.embedding_mode; + slot->task_id = task.id; + slot->multitask_id = task.multitask_id; + + if (!LaunchSlotWithData(slot, task.data)) { + // send error result + SendError(task, "internal_error"); + break; + } + } break; + case TaskType::kCancelTask: { // release slot linked with the task id + for (auto& slot : slots) { + if (slot.task_id == task.target_id) { + slot.Release(); + break; } } - } + } break; } } - slot->params.antiprompt.clear(); + // remove finished multitasks from the queue of multitasks, and add the + // corresponding result to the result queue + auto queue_iterator = queue_multitasks.begin(); + while (queue_iterator != queue_multitasks.end()) { + if (queue_iterator->subtasks_remaining.empty()) 
{ + // all subtasks done == multitask is done + TaskResult aggregate_result; + aggregate_result.id = queue_iterator->id; + aggregate_result.stop = true; + aggregate_result.error = false; - const auto& stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto& word : *stop) { - if (!word.empty()) { - slot->params.antiprompt.push_back(word); + // collect json results into one json result + std::vector result_jsons; + for (auto& subres : queue_iterator->results) { + result_jsons.push_back(subres.result_json); + aggregate_result.error = aggregate_result.error && subres.error; } - } - } + aggregate_result.result_json = json{"results", result_jsons}; - if (multimodal) { - const auto& images_data = data.find("image_data"); - if (images_data != data.end() && images_data->is_array()) { - for (const auto& img : *images_data) { - const std::vector image_buffer = - base64_decode(img["data"].get()); + std::lock_guard lock(mutex_results); + queue_results.push_back(aggregate_result); + condition_results.notify_all(); - SlotImage img_sl; - img_sl.id = - img.count("id") != 0 ? img["id"].get() : slot->images.size(); - img_sl.img_data = clip_image_u8_init(); - if (!clip_image_load_from_bytes(image_buffer.data(), - image_buffer.size(), img_sl.img_data)) { - LOG_DEBUG << "slot " << slot->id - << " - failed to load image [id: " << img_sl.id << "]"; - return false; - } - LOG_DEBUG << "slot " << slot->id << " - loaded image"; - img_sl.request_encode_image = true; - slot->images.push_back(img_sl); - } - // process prompt - // example: system prompt [img-102] user [img-103] describe [img-134] -> - // [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, - // {id: 134, prefix: ' describe '}]} - if (slot->images.size() > 0 && !slot->prompt.is_array()) { - std::string prompt = slot->prompt.get(); - size_t pos = 0, begin_prefix = 0; - std::string pattern = "[img-"; - while ((pos = prompt.find(pattern, pos)) != std::string::npos) { - size_t end_prefix = pos; - pos += pattern.length(); - size_t end_pos = prompt.find("]", pos); - if (end_pos != std::string::npos) { - std::string image_id = prompt.substr(pos, end_pos - pos); - try { - int img_id = std::stoi(image_id); - bool found = false; - for (SlotImage& img : slot->images) { - if (img.id == img_id) { - found = true; - img.prefix_prompt = - prompt.substr(begin_prefix, end_prefix - begin_prefix); - begin_prefix = end_pos + 1; - break; - } - } - if (!found) { - LOG_WARN << "ERROR: Image with id: " << img_id - << ", not found.\n"; - slot->images.clear(); - return false; - } - } catch (const std::invalid_argument& e) { - LOG_WARN << "Invalid image number id in prompt: " << e.what(); - slot->images.clear(); - return false; - } - } - } - slot->prompt = ""; - slot->params.input_suffix = prompt.substr(begin_prefix); - slot->params.cache_prompt = - false; // multimodal doesn't support cache prompt - } + queue_iterator = queue_multitasks.erase(queue_iterator); + } else { + ++queue_iterator; } } +} - if (slot->ctx_sampling != nullptr) { - llama_sampling_free(slot->ctx_sampling); +void LlamaServerContext::DoBackgroundTasks() { + while (model_loaded_external) { + UpdateSlots(); } - slot->ctx_sampling = llama_sampling_init(slot->sparams); - llama_set_rng_seed(ctx, slot->params.seed); - slot->command = SlotCommand::kLoadPrompt; - slot->prompt_tokens.clear(); - - all_slots_are_idle = false; - - LOG_DEBUG << "slot " << slot->id - << " is processing [task id: " << slot->task_id << "]"; - - return true; + LOG_INFO << "Background task stopped! 
"; + KvCacheClear(); + LOG_INFO << "KV cache cleared!"; } -void LlamaServerContext::UpdateSystemPrompt() { - system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token); - - llama_batch_clear(batch); - - KvCacheClear(); +bool LlamaServerContext::UpdateSlots() { + // attend tasks + ProcessTasks(); - for (int i = 0; i < (int)system_tokens.size(); ++i) { - llama_batch_add(batch, system_tokens[i], i, {0}, false); + // update the system prompt wait until all slots are idle state + if (system_need_update && all_slots_are_idle) { + LOG_DEBUG << "updating system prompt"; + UpdateSystemPrompt(); } - if (llama_decode(ctx, batch) != 0) { - LOG_WARN << __func__ << ": llama_decode() failed"; - return; - } + llama_batch_clear(batch); - // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i < params.n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size()); + if (all_slots_are_idle) { + if (system_prompt.empty() && clean_kv_cache) { + LOG_DEBUG + << "all slots are idle and system prompt is empty, clear the KV " + "cache"; + KvCacheClear(); + } + // std::this_thread::sleep_for(std::chrono::milliseconds(5)); + // TODO: Need to implement queueing using CV for better performance + std::unique_lock lock(mutex_tasks); + condition_tasks.wait(lock, [&] { + return (!queue_tasks.empty() && model_loaded_external) || + (queue_tasks.empty() && !model_loaded_external); + }); } - LOG_DEBUG << "system prompt updated"; - system_need_update = false; -} - -void LlamaServerContext::NotifySystemPromptChanged() { - // release all slots for (LlamaClientSlot& slot : slots) { - slot.Release(); - } - - system_need_update = true; -} - -void LlamaServerContext::ProcessSystemPromptData(const json& sys_props) { - system_prompt = sys_props.value("prompt", ""); - name_user = sys_props.value("anti_prompt", ""); - name_assistant = sys_props.value("assistant_name", ""); - - if (slots.size() > 0) { - NotifySystemPromptChanged(); - } -} + if (slot.IsProcessing() && + (int)system_tokens.size() + slot.n_past >= slot.n_ctx) { + // Shift context + const int n_left = slot.n_past - slot.params.n_keep - 1; + const int n_discard = n_left / 2; -size_t LlamaServerContext::FindStoppingStrings(const std::string& text, - const size_t last_token_size, - const StopType type, - LlamaClientSlot& slot) { - size_t stop_pos = std::string::npos; + LOG_DEBUG << "slot " << slot.id + << " context shift - n_keep = " << slot.params.n_keep + << ", n_left = " << n_left << ", n_discard: " << n_discard + << ", n_ctx = " << n_ctx << ", n_past = " << slot.n_past + << ", n_system_tokens = " << system_tokens.size() + << ", n_cache_tokens = " << slot.cache_tokens.size(); - for (const std::string& word : slot.params.antiprompt) { - size_t pos; - if (type == StopType::kStopFull) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? 
text.size() - tmp : 0; - pos = text.find(word, from_pos); - } else { - pos = find_partial_stop_string(word, text); - } - if (pos != std::string::npos && - (stop_pos == std::string::npos || pos < stop_pos)) { - if (type == StopType::kStopFull) { - slot.stopped_word = true; - slot.stopping_word = word; - slot.has_next_token = false; - } - stop_pos = pos; - } - } + llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, + slot.params.n_keep + n_discard + 1); + llama_kv_cache_seq_add(ctx, slot.id, slot.params.n_keep + 1 + n_discard, + slot.n_past, -n_discard); - return stop_pos; -} + if (slot.params.cache_prompt) { + for (size_t i = slot.params.n_keep + 1 + n_discard; + i < slot.cache_tokens.size(); i++) { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + } -bool LlamaServerContext::ProcessToken(CompletionTokenOutput& result, - LlamaClientSlot& slot) { - // remember which tokens were sampled - used for repetition penalties during - // sampling - const std::string token_str = llama_token_to_piece(ctx, result.tok); - slot.sampled = result.tok; + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + } - // search stop word and delete it - slot.generated_text += token_str; - slot.has_next_token = true; + slot.n_past -= n_discard; - if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) { - // we can change penalty_prompt_tokens because it is always created from - // scratch each request - slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); + slot.truncated = true; + } } - // check if there is incomplete UTF-8 character at the end - bool incomplete = false; - for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) { - unsigned char c = slot.generated_text[slot.generated_text.size() - i]; - if ((c & 0xC0) == 0x80) { - // continuation byte: 10xxxxxx + // decode any currently ongoing sequences + for (auto& slot : slots) { + // release the slot + if (slot.command == SlotCommand::kRelease) { + slot.state = SlotState::kIdle; + slot.command = SlotCommand::kNone; + slot.t_last_used = ggml_time_us(); + + LOG_INFO << "slot released: " << "id_slot: " << slot.id + << ", id_task: " << slot.task_id << ", n_ctx: " << n_ctx + << ", n_past: " << slot.n_past + << ", n_system_tokens: " << system_tokens.size() + << ", n_cache_tokens: " << slot.cache_tokens.size() + << ", truncated: " << slot.truncated; + continue; } - if ((c & 0xE0) == 0xC0) { - // 2-byte character: 110xxxxx ... - incomplete = i < 2; - } else if ((c & 0xF0) == 0xE0) { - // 3-byte character: 1110xxxx ... - incomplete = i < 3; - } else if ((c & 0xF8) == 0xF0) { - // 4-byte character: 11110xxx ... 
- incomplete = i < 4; - } - // else 1-byte character or invalid byte - break; - } - if (!incomplete) { - size_t pos = std::min(slot.sent_count, slot.generated_text.size()); - const std::string str_test = slot.generated_text.substr(pos); - bool is_stop_full = false; - size_t stop_pos = FindStoppingStrings(str_test, token_str.size(), - StopType::kStopFull, slot); - if (stop_pos != std::string::npos) { - is_stop_full = true; - slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, - slot.generated_text.end()); - pos = std::min(slot.sent_count, slot.generated_text.size()); - } else { - is_stop_full = false; - stop_pos = FindStoppingStrings(str_test, token_str.size(), - StopType::kStopPartial, slot); + if (slot.state == SlotState::kIdle) { + continue; } - // check if there is any token to predict - if (stop_pos == std::string::npos || - (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { - // no send the stop word in the response - result.text_to_send = slot.generated_text.substr(pos, std::string::npos); - slot.sent_count += result.text_to_send.size(); - // add the token to slot queue and cache - } - slot.AddTokenString(result); - if (slot.params.stream) { - SendPartialResponse(slot, result); - } - } + slot.i_batch = batch.n_tokens; - if (incomplete) { - slot.has_next_token = true; - } + llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, + {slot.id}, true); - // check the limits - if (slot.n_decoded > 2 && slot.has_next_token && !slot.HasBudget(params)) { - slot.stopped_limit = true; - slot.has_next_token = false; - } + slot.n_decoded += 1; + slot.n_past += 1; - if (result.tok == llama_token_eos(model)) { - slot.stopped_eos = true; - slot.has_next_token = false; - LOG_VERBOSE("eos token found", {}); + if (slot.params.cache_prompt) { + slot.cache_tokens.push_back(slot.sampled); + } + + LOG_TRACE << "slot decode token - " << " id_slot: " << slot.id + << ", task_id: " << slot.task_id << ", n_ctx: " << n_ctx + << ", n_past: " << slot.n_past + << ", n_system_tokens: " << system_tokens.size() + << ", n_cache_tokens: " << slot.cache_tokens.size() + << ", truncated: " << slot.truncated; } - LOG_VERBOSE( - "next token", - { - {"token", result.tok}, - {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"num_tokens_predicted", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }); + // process in chunks of params.n_batch + int32_t n_batch = llama_n_batch(ctx); + int32_t n_ubatch = llama_n_ubatch(ctx); - return slot.has_next_token; // continue -} -bool LlamaServerContext::ProcessImages(LlamaClientSlot& slot) const { - for (SlotImage& img : slot.images) { - if (!img.request_encode_image) { - continue; - } + // assign workload to the slots + if (params.cont_batching || batch.n_tokens == 0) { + for (auto& slot : slots) { + const bool has_prompt = slot.prompt.is_array() || + (slot.prompt.is_string() && + !slot.prompt.get().empty()) || + !slot.images.empty(); - if (!llava_image_embed_make_with_clip_img( - clp_ctx, params.n_threads, img.img_data, &img.image_embedding, - &img.image_tokens)) { - LOG_WARN << "Error processing the given image"; - return false; - } + // empty prompt passed -> release the slot and send empty response + if (slot.state == SlotState::kIdle && + slot.command == SlotCommand::kLoadPrompt && !has_prompt) { + slot.Release(); + 
slot.PrintTimings(); + SendFinalResponse(slot); + continue; + } - img.request_encode_image = false; - } + // need process the prompt + if (slot.state == SlotState::kIdle && + slot.command == SlotCommand::kLoadPrompt) { + auto& prompt_tokens = slot.prompt_tokens; - return slot.images.size() > 0; -} -void LlamaServerContext::SendError(TaskServer& task, std::string error) { - std::unique_lock lock(mutex_results); - TaskResult res; - res.id = task.id; - res.multitask_id = task.multitask_id; - res.stop = false; - res.error = true; - res.result_json = {{"content", error}}; - queue_results.push_back(res); - condition_results.notify_all(); -} + // we haven't tokenized the prompt yet - do it now: + if (prompt_tokens.empty()) { + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_genereration = 0; -void LlamaServerContext::AddMultiTask(int id, std::vector& sub_ids) { - std::lock_guard lock(mutex_tasks); - TaskMulti multi; - multi.id = id; - std::copy( - sub_ids.begin(), sub_ids.end(), - std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); - queue_multitasks.push_back(multi); - condition_tasks.notify_one(); -} + if (slot.infill) { + bool suff_rm_leading_spc = true; + if (params.input_suffix.find_first_of(' ') == 0 && + params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; + } + auto prefix_tokens = Tokenize(slot.params.input_prefix, false); + auto suffix_tokens = Tokenize(slot.params.input_suffix, false); -void LlamaServerContext::UpdateMultiTask(int multitask_id, int subtask_id, - TaskResult& result) { - std::lock_guard lock(mutex_tasks); - for (auto& multitask : queue_multitasks) { - if (multitask.id == multitask_id) { - multitask.subtasks_remaining.erase(subtask_id); - multitask.results.push_back(result); - condition_tasks.notify_one(); - } - } -} + const int space_token = + 29871; // TODO: this should not be hardcoded + if (suff_rm_leading_spc && !suffix_tokens.empty() && + suffix_tokens[0] == space_token) { + suffix_tokens.erase(suffix_tokens.begin()); + } -json LlamaServerContext::GetFormatedGeneration(LlamaClientSlot& slot) { - const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); - const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && - eos_bias->second < 0.0f && - std::isinf(eos_bias->second); - return json{ - {"n_ctx", slot.n_ctx}, - {"model", params.model_alias}, - {"seed", slot.params.seed}, - {"temperature", slot.sparams.temp}, - {"top_k", slot.sparams.top_k}, - {"top_p", slot.sparams.top_p}, - {"min_p", slot.sparams.min_p}, - {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typical_p}, - {"repeat_last_n", slot.sparams.penalty_last_n}, - {"repeat_penalty", slot.sparams.penalty_repeat}, - {"presence_penalty", slot.sparams.penalty_present}, - {"frequency_penalty", slot.sparams.penalty_freq}, - {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, - {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, - {"mirostat", slot.sparams.mirostat}, - {"mirostat_tau", slot.sparams.mirostat_tau}, - {"mirostat_eta", slot.sparams.mirostat_eta}, - {"penalize_nl", slot.sparams.penalize_nl}, - {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, - {"n_keep", params.n_keep}, - {"ignore_eos", ignore_eos}, - {"stream", slot.params.stream}, - {"logit_bias", slot.sparams.logit_bias}, - {"n_probs", slot.sparams.n_probs}, - {"grammar", slot.sparams.grammar}, - }; -} + prefix_tokens.insert(prefix_tokens.begin(), + llama_token_prefix(model)); + 
prefix_tokens.insert(prefix_tokens.begin(), + llama_token_bos(model)); // always add BOS + prefix_tokens.insert(prefix_tokens.end(), + llama_token_suffix(model)); + prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), + suffix_tokens.end()); + prefix_tokens.push_back(llama_token_middle(model)); + prompt_tokens = prefix_tokens; + } else { + prompt_tokens = Tokenize( + slot.prompt, + system_prompt.empty() && + add_bos_token); // add BOS if there isn't system prompt + } -void LlamaServerContext::SendPartialResponse(LlamaClientSlot& slot, - CompletionTokenOutput tkn) { - std::unique_lock lock(mutex_results); - TaskResult res; - res.id = slot.task_id; - res.multitask_id = slot.multitask_id; - res.error = false; - res.stop = false; + slot.n_past = 0; + slot.num_prompt_tokens = prompt_tokens.size(); - res.result_json = json{{"content", tkn.text_to_send}, - {"stop", false}, - {"slot_id", slot.id}, - {"multimodal", multimodal}}; + LOG_VERBOSE( + "prompt tokenized", + { + {"id_slot", slot.id}, + {"id_task", slot.task_id}, + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_prompt_tokens", slot.num_prompt_tokens}, + {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), + prompt_tokens.cend())}, + }); - if (slot.sparams.n_probs > 0) { - std::vector probs_output = {}; - const std::vector to_send_toks = - llama_tokenize(ctx, tkn.text_to_send, false); - size_t probs_pos = std::min(slot.sent_token_probs_index, - slot.generated_token_probs.size()); - size_t probs_stop_pos = - std::min(slot.sent_token_probs_index + to_send_toks.size(), - slot.generated_token_probs.size()); - if (probs_pos < probs_stop_pos) { - probs_output = std::vector( - slot.generated_token_probs.begin() + probs_pos, - slot.generated_token_probs.begin() + probs_stop_pos); - } - slot.sent_token_probs_index = probs_stop_pos; - res.result_json["completion_probabilities"] = - probs_vector_to_json(ctx, probs_output); - } + if (slot.embedding) { + // this prompt is too large to process - discard it + if (slot.num_prompt_tokens > n_ubatch) { + LOG_DEBUG << "This prompt is too large to process: " + "num_promt_tokens = " + << slot.num_prompt_tokens + << ", n_ubatch = " << n_ubatch; + slot.state = SlotState::kProcessing; + slot.command = SlotCommand::kNone; + slot.Release(); + slot.PrintTimings(); + SendFinalResponse(slot); + continue; + } + } else { + if (slot.params.n_keep < 0) { + slot.params.n_keep = slot.num_prompt_tokens; + } + slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - if (slot.oaicompat) { - res.result_json["oaicompat_token_ctr"] = slot.n_decoded; - res.result_json["model"] = slot.oaicompat_model; - } + // if input prompt is too big, truncate it + if (slot.num_prompt_tokens >= slot.n_ctx) { + const int n_left = slot.n_ctx - slot.params.n_keep; + const int n_block_size = n_left / 2; + const int erased_blocks = + (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / + n_block_size; - queue_results.push_back(res); - condition_results.notify_all(); -} + std::vector new_tokens( + prompt_tokens.begin(), + prompt_tokens.begin() + slot.params.n_keep); + new_tokens.insert(new_tokens.end(), + prompt_tokens.begin() + slot.params.n_keep + + erased_blocks * n_block_size, + prompt_tokens.end()); -void LlamaServerContext::SendFinalResponse(LlamaClientSlot& slot) { - std::unique_lock lock(mutex_results); - TaskResult res; - res.id = slot.task_id; - res.multitask_id = slot.multitask_id; - res.error = false; - res.stop = true; + LOG_VERBOSE("input truncated", + { + {"id_slot", slot.id}, + 
{"id_task", slot.task_id}, + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_left", n_left}, + {"n_prompt_tokens", slot.num_prompt_tokens}, + {"prompt_tokens", + tokens_to_str(ctx, prompt_tokens.cbegin(), + prompt_tokens.cend())}, + }); - res.result_json = - json{{"content", !slot.params.stream ? slot.generated_text : ""}, - {"slot_id", slot.id}, - {"stop", true}, - {"model", params.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.num_prompt_tokens}, - {"generation_settings", GetFormatedGeneration(slot)}, - {"prompt", slot.prompt}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.GetFormatedTimings()}}; + slot.truncated = true; + prompt_tokens = new_tokens; - if (slot.sparams.n_probs > 0) { - std::vector probs = {}; - if (!slot.params.stream && slot.stopped_word) { - const std::vector stop_word_toks = - llama_tokenize(ctx, slot.stopping_word, false); - probs = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end() - stop_word_toks.size()); - } else { - probs = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.begin() + slot.sent_token_probs_index); - } - res.result_json["completion_probabilities"] = - probs_vector_to_json(ctx, probs); - } + slot.num_prompt_tokens = prompt_tokens.size(); + GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx); + } - if (slot.oaicompat) { - res.result_json["oaicompat_token_ctr"] = slot.n_decoded; - res.result_json["model"] = slot.oaicompat_model; - } + llama_sampling_reset(slot.ctx_sampling); - // parent multitask, if any, needs to be updated - if (slot.multitask_id != -1) { - UpdateMultiTask(slot.multitask_id, slot.task_id, res); - } + if (!slot.params.cache_prompt) { + slot.n_past = 0; + slot.num_prompt_tokens_processed = slot.num_prompt_tokens; + } else { + // push the prompt into the sampling context (do not apply grammar) + for (auto& token : prompt_tokens) { + llama_sampling_accept(slot.ctx_sampling, ctx, token, false); + } - queue_results.push_back(res); - condition_results.notify_all(); -} + slot.n_past = common_part(slot.cache_tokens, prompt_tokens); + slot.num_prompt_tokens_processed = + slot.num_prompt_tokens - slot.n_past; -void LlamaServerContext::SendEmbedding(LlamaClientSlot& slot) { - std::unique_lock lock(mutex_results); - TaskResult res; - res.id = slot.task_id; - res.multitask_id = slot.multitask_id; - res.error = false; - res.stop = true; + LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", + slot.id, slot.n_past, slot.num_prompt_tokens_processed); + } + } - const int n_embd = llama_n_embd(model); + if (slot.n_past == slot.num_prompt_tokens) { + // we have to evaluate at least 1 token to generate logits. 
+ LOG_DEBUG << "slot " << slot.id + << " : we have to evaluate at least 1 token to " + "generate logits"; + slot.n_past--; + } - std::vector embd_res(n_embd, 0.0f); + slot.num_prompt_tokens_processed = 0; + } - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } + if (slot.embedding) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.num_prompt_tokens > n_batch) { + continue; + } + } - const float* embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } + LOG_VERBOSE( + "prompt ingested", + { + {"n_past", slot.n_past}, + {"cached", + tokens_to_str(ctx, slot.cache_tokens.cbegin(), + slot.cache_tokens.cbegin() + slot.n_past)}, + {"to_eval", + tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, + slot.cache_tokens.cend())}, + }); - if (embd == NULL) { - LOG_ERROR << "failed to get embeddings" - << " token " << batch.token[i] << ", seq_id " - << batch.seq_id[i][0]; + // keep only the common part + int p0 = (int)system_tokens.size() + slot.n_past; + if (!llama_kv_cache_seq_rm(ctx, slot.id, p0, -1)) { + // could not partially delete (likely using a non-Transformer model) + llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); - res.result_json = json{ - {"embedding", std::vector(n_embd, 0.0f)}, - }; + p0 = (int)system_tokens.size(); + if (p0 != 0) { + // copy over the system prompt when there is one + llama_kv_cache_seq_cp(ctx, 0, slot.id, -1, -1); + } - continue; - } + // there is no common part left (except for the system prompt) + slot.n_past = 0; + // TODO: is the system prompt ever in the sampling context? + llama_sampling_reset(slot.ctx_sampling); + } + + // remove the non-common part from the cache + slot.cache_tokens.resize(slot.n_past); + LOG_INFO << "kv cache rm [p0, end) - " << " id_slot: " << slot.id + << ", task_id: " << slot.task_id << ", p0: " << p0; + + const bool has_images = ProcessImages(slot); + + // process the prefix of first image + std::vector prefix_tokens = + has_images ? 
Tokenize(slot.images[0].prefix_prompt, add_bos_token) + : prompt_tokens; + for (; slot.n_past < (int)prefix_tokens.size(); ++slot.n_past) { + llama_batch_add(batch, prefix_tokens[slot.n_past], + system_tokens.size() + slot.n_past, {slot.id}, false); + if (slot.params.cache_prompt) { + slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); + } + slot.num_prompt_tokens_processed++; + } - llama_embd_normalize(embd, embd_res.data(), n_embd); - } - res.result_json = json{ - {"embedding", embd_res}, - }; + LOG_VERBOSE("prompt processing progress", + { + {"id_slot", slot.id}, + {"n_past", slot.n_past}, + {"n_ctx", n_ctx}, + {"n_tokens", batch.n_tokens}, + {"progress", (float)slot.num_prompt_tokens_processed / + slot.num_prompt_tokens}, + }); - queue_results.push_back(res); - condition_results.notify_all(); -} + if (has_images && !IngestImages(slot, n_batch)) { + LOG_WARN << "failed processing images"; + return false; + } -// for multiple images processing -bool LlamaServerContext::IngestImages(LlamaClientSlot& slot, int n_batch) { - int image_idx = 0; + // entire prompt has been processed - start decoding new tokens + if (slot.n_past == slot.num_prompt_tokens) { + slot.state = SlotState::kProcessing; + slot.command = SlotCommand::kNone; - while (image_idx < (int)slot.images.size()) { - SlotImage& img = slot.images[image_idx]; + GGML_ASSERT(batch.n_tokens > 0); - // process prefix prompt - for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, - 0, - 0, // unused - }; - if (llama_decode(ctx, batch_view)) { - LOG_WARN << __func__ << " : failed to eval\n"; - return false; - } - } + // extract the logits only for the last token + batch.logits[batch.n_tokens - 1] = true; - // process image with llm - for (int i = 0; i < img.image_tokens; i += n_batch) { - int n_eval = img.image_tokens - i; - if (n_eval > n_batch) { - n_eval = n_batch; - } + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; - const int n_embd = llama_n_embd(model); - llama_batch batch_img = { - n_eval, nullptr, (img.image_embedding + i * n_embd), - nullptr, nullptr, nullptr, - nullptr, slot.n_past, 1, - 0, - }; - if (llama_decode(ctx, batch_img)) { - LOG_DEBUG << __func__ << " : failed to eval image"; - return false; + LOG_VERBOSE("prompt done", { + {"id_slot", slot.id}, + {"n_past", slot.n_past}, + {"n_ctx", n_ctx}, + {"n_tokens", batch.n_tokens}, + }); + } } - slot.n_past += n_eval; } - image_idx++; - - llama_batch_clear(batch); - - // append prefix of next image - const auto json_prompt = - (image_idx >= (int)slot.images.size()) - ? 
slot.params.input_suffix - : // no more images, then process suffix prompt - (json)(slot.images[image_idx].prefix_prompt); + } - std::vector append_tokens = - Tokenize(json_prompt, false); // has next image - for (int i = 0; i < (int)append_tokens.size(); ++i) { - llama_batch_add(batch, append_tokens[i], slot.n_past, {slot.id}, true); - slot.n_past += 1; - } + if (batch.n_tokens == 0) { + all_slots_are_idle = true; + return true; } - return true; -} + for (int32_t i = 0; i < (int32_t)batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t)(batch.n_tokens - i)); + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, + 0, + 0, // unused + }; -int LlamaServerContext::SplitMultipromptTask(TaskServer& multiprompt_task) { - int prompt_count = multiprompt_task.data.at("prompt").size(); - assert(prompt_count > 1); + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + if (n_batch == 1 || ret < 0) { + // if you get here, it means the KV cache is full - try increasing it via the context size + LOG_ERROR << "Failed to decode the batch: KV cache is full - try " + "increasing it via the context size: " + << "i = " << i << ", n_batch = " << n_batch + << ", ret = " << ret; + for (auto& slot : slots) { + slot.state = SlotState::kProcessing; + slot.command = SlotCommand::kNone; + slot.Release(); + // SendError(slot, + // "Input prompt is too big compared to KV size. Please " + // "try increasing KV size."); + } + break; // break loop of n_batch + } - int multitask_id = id_gen++; - std::vector subtask_ids(prompt_count); - for (int i = 0; i < prompt_count; i++) { - json subtask_data = multiprompt_task.data; - subtask_data["prompt"] = subtask_data["prompt"][i]; + LOG_WARN << "Failed to find free space in the KV cache, retrying with " + "smaller n_batch = " + << n_batch / 2; - // subtasks inherit everything else (infill mode, embedding mode, etc.) 
- subtask_ids[i] = - RequestCompletion(subtask_data, multiprompt_task.infill_mode, - multiprompt_task.embedding_mode, multitask_id); - } + // retry with half the batch size to try to find a free slot in the KV + // cache + n_batch /= 2; + i -= n_batch; + continue; + } - // queue up the multitask so we can track its subtask progression - AddMultiTask(multitask_id, subtask_ids); - return multitask_id; -} + for (auto& slot : slots) { + if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { + continue; + } -void LlamaServerContext::ProcessTasks() { - std::unique_lock lock(mutex_tasks); - while (!queue_tasks.empty()) { - TaskServer task = queue_tasks.front(); - queue_tasks.erase(queue_tasks.begin()); - switch (task.type) { - case TaskType::kCompletionTask: { - LlamaClientSlot* slot = GetSlot(json_value(task.data, "slot_id", -1)); - if (slot == nullptr) { - LOG_WARN << "slot unavailable"; - // send error result - SendError(task, "slot unavailable"); - return; - } + // prompt evaluated for embedding + if (slot.embedding) { + SendEmbedding(slot); + slot.Release(); + slot.i_batch = -1; + return true; + } - if (task.data.contains("system_prompt")) { - ProcessSystemPromptData(task.data["system_prompt"]); - } + CompletionTokenOutput result; + const llama_token id = + llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i); - slot->Reset(); + llama_sampling_accept(slot.ctx_sampling, ctx, id, true); - slot->infill = task.infill_mode; - slot->embedding = task.embedding_mode; - slot->task_id = task.id; - slot->multitask_id = task.multitask_id; + if (slot.n_decoded == 1) { + slot.t_start_genereration = ggml_time_us(); + slot.t_prompt_processing = + (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3; + } - if (!LaunchSlotWithData(slot, task.data)) { - // send error result - SendError(task, "internal_error"); - break; - } - } break; - case TaskType::kCancelTask: { // release slot linked with the task id - for (auto& slot : slots) { - if (slot.task_id == task.target_id) { - slot.Release(); - break; - } - } - } break; - } - } + llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(), + slot.ctx_sampling->cur.size(), false}; + result.tok = id; - // remove finished multitasks from the queue of multitasks, and add the - // corresponding result to the result queue - auto queue_iterator = queue_multitasks.begin(); - while (queue_iterator != queue_multitasks.end()) { - if (queue_iterator->subtasks_remaining.empty()) { - // all subtasks done == multitask is done - TaskResult aggregate_result; - aggregate_result.id = queue_iterator->id; - aggregate_result.stop = true; - aggregate_result.error = false; + const int32_t n_probs = slot.sparams.n_probs; + if (slot.sparams.temp <= 0 && n_probs > 0) { + // for llama_sample_token_greedy we need to sort candidates + llama_sample_softmax(ctx, &cur_p); + } - // collect json results into one json result - std::vector result_jsons; - for (auto& subres : queue_iterator->results) { - result_jsons.push_back(subres.result_json); - aggregate_result.error = aggregate_result.error && subres.error; + for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i) { + result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); } - aggregate_result.result_json = json{"results", result_jsons}; - std::lock_guard lock(mutex_results); - queue_results.push_back(aggregate_result); - condition_results.notify_all(); + if (!ProcessToken(result, slot)) { + slot.Release(); + slot.PrintTimings(); + SendFinalResponse(slot); + } - queue_iterator = 
queue_multitasks.erase(queue_iterator); - } else { - ++queue_iterator; + slot.i_batch = -1; } } + return true; } diff --git a/src/llama_server_context.h b/src/llama_server_context.h index 17082dd..38b7163 100644 --- a/src/llama_server_context.h +++ b/src/llama_server_context.h @@ -62,7 +62,7 @@ enum class StopType : uint8_t { kStopPartial, }; -enum class ModelType: uint8_t { kLlm = 0, kEmbedding }; +enum class ModelType : uint8_t { kLlm = 0, kEmbedding }; // TODO: reuse llama_detokenize template @@ -142,6 +142,7 @@ struct LlamaServerContext { std::condition_variable condition_tasks; std::mutex mutex_results; std::condition_variable condition_results; + std::thread bgr_thread; ModelType model_type = ModelType::kLlm; ~LlamaServerContext(); @@ -152,11 +153,11 @@ struct LlamaServerContext { void KvCacheClear(); json GetModelProps(); int RequestCompletion(json data, bool infill, bool embedding, - int multitask_id); + int multitask_id); TaskResult NextResult(int task_id); void RequestCancel(int task_id); - bool UpdateSlots(); + void ReleaseResources(); private: std::vector Tokenize(const json& json_prompt, @@ -173,8 +174,8 @@ struct LlamaServerContext { void ProcessSystemPromptData(const json& sys_props); size_t FindStoppingStrings(const std::string& text, - const size_t last_token_size, - const StopType type, LlamaClientSlot& slot); + const size_t last_token_size, const StopType type, + LlamaClientSlot& slot); bool ProcessToken(CompletionTokenOutput& result, LlamaClientSlot& slot); @@ -201,4 +202,7 @@ struct LlamaServerContext { void ProcessTasks(); + void DoBackgroundTasks(); + + bool UpdateSlots(); }; diff --git a/src/llama_utils.h b/src/llama_utils.h index 9266389..6dd0daf 100644 --- a/src/llama_utils.h +++ b/src/llama_utils.h @@ -167,4 +167,27 @@ inline void ltrim(std::string& s) { })); }; +inline std::string GetModelId(const Json::Value& jsonBody) { + // First check if model exists in request + if (!jsonBody["model"].isNull()) { + return jsonBody["model"].asString(); + } else if (!jsonBody["model_alias"].isNull()) { + return jsonBody["model_alias"].asString(); + } + + // We check llama_model_path for loadmodel request + if (auto input = jsonBody["llama_model_path"]; !input.isNull()) { + auto s = input.asString(); + std::replace(s.begin(), s.end(), '\\', '/'); + auto const pos = s.find_last_of('/'); + // We only truncate the extension if file name has gguf extension + if (s.substr(s.find_last_of('.') + 1) == "gguf") { + return s.substr(pos + 1, s.find_last_of('.') - pos - 1); + } else { + return s.substr(pos + 1); + } + } + return {}; +} + } // namespace llama_utils \ No newline at end of file
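
The new llama_utils::GetModelId helper in the last hunk resolves a request's model id by checking the "model" field first, then "model_alias", and finally falling back to the file stem of "llama_model_path" (stripping a ".gguf" extension when present). The minimal usage sketch below is illustrative only; it assumes jsoncpp is available as <json/json.h> and that src/llama_utils.h is on the include path, and the paths and model names are made up for the example.

#include <cassert>
#include <json/json.h>    // jsoncpp, the JSON library the engine already uses

#include "llama_utils.h"  // assumed include path for the helper shown above

int main() {
  // Load request without an explicit alias: the id falls back to the gguf file stem.
  Json::Value load_req;
  load_req["llama_model_path"] = "/tmp/testllm.gguf";
  assert(llama_utils::GetModelId(load_req) == "testllm");

  // Chat completion request: an explicit "model" field takes precedence.
  Json::Value chat_req;
  chat_req["model"] = "testllm";
  assert(llama_utils::GetModelId(chat_req) == "testllm");

  return 0;
}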
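
UpdateSlots() now blocks on condition_tasks instead of polling with a sleep, waking either when a task arrives while the model is loaded or when model_loaded_external is cleared so DoBackgroundTasks() can leave its loop and clear the KV cache. The standalone sketch below shows that shutdown-aware condition-variable wait in isolation; it is a simplified illustration, not the engine's code: tasks, running, and worker are hypothetical stand-ins for queue_tasks, model_loaded_external, and the new bgr_thread, and the wait predicate is reduced to its essence.

#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <thread>

int main() {
  std::deque<int> tasks;   // stand-in for queue_tasks
  bool running = true;     // stand-in for model_loaded_external
  std::mutex m;
  std::condition_variable cv;

  std::thread worker([&] {
    for (;;) {
      std::unique_lock<std::mutex> lock(m);
      // Sleep until there is work to do or shutdown is requested.
      cv.wait(lock, [&] { return !tasks.empty() || !running; });
      if (!running && tasks.empty()) {
        break;  // mirrors DoBackgroundTasks() exiting its loop
      }
      int task = tasks.front();
      tasks.pop_front();
      lock.unlock();
      std::cout << "processed task " << task << "\n";
    }
  });

  {
    std::lock_guard<std::mutex> lock(m);
    tasks.push_back(42);  // enqueue one unit of work
  }
  cv.notify_one();

  {
    std::lock_guard<std::mutex> lock(m);
    running = false;  // request shutdown, like clearing model_loaded_external
  }
  cv.notify_all();

  worker.join();
  return 0;
}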