From a8b2503bd045d6ee3234d1487cbf7088535f83fd Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 23 Dec 2024 11:23:16 +0700 Subject: [PATCH 01/16] feat: add filter compatible for engine variant api (#1819) --- docs/static/openapi/cortex.json | 78 +++++++++++++++++++++++++++++++ engine/controllers/engines.cc | 16 ++++++- engine/controllers/engines.h | 12 ++--- engine/services/engine_service.cc | 41 +++++++++++++++- engine/services/engine_service.h | 3 +- 5 files changed, 139 insertions(+), 11 deletions(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index a05f8b24e..479e300ce 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -2199,6 +2199,84 @@ "tags": ["Engines"] } }, + "/v1/engines/{name}/releases/{version}": { + "get": { + "summary": "List variants for a specific engine version", + "description": "Lists all available variants (builds) for a specific version of an engine. Variants can include different CPU architectures (AVX, AVX2, AVX512), GPU support (CUDA, Vulkan), and operating systems (Windows, Linux, macOS).", + "parameters": [ + { + "name": "name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "default": "llama-cpp" + }, + "description": "The type of engine" + }, + { + "name": "version", + "in": "path", + "required": true, + "schema": { + "type": "string" + }, + "description": "The version of the engine" + }, + { + "name": "show", + "in": "query", + "required": false, + "schema": { + "type": "string", + "enum": ["all", "compatible"], + "default": "all" + }, + "description": "Filter the variants list. Use 'compatible' to show only variants compatible with the current system, or 'all' to show all available variants." 
+ } + ], + "responses": { + "200": { + "description": "Successfully retrieved variants list", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name of the variant, including OS, architecture, and capabilities", + "example": "linux-amd64-avx-cuda-11-7" + }, + "created_at": { + "type": "string", + "format": "date-time", + "description": "Creation timestamp of the variant", + "example": "2024-11-13T04:51:16Z" + }, + "size": { + "type": "integer", + "description": "Size of the variant in bytes", + "example": 151224604 + }, + "download_count": { + "type": "integer", + "description": "Number of times this variant has been downloaded", + "example": 0 + } + } + } + } + } + } + } + }, + "tags": ["Engines"] + } + }, "/v1/engines/{name}/releases/latest": { "get": { "summary": "Get latest release", diff --git a/engine/controllers/engines.cc b/engine/controllers/engines.cc index a92d6805f..24e61ba4f 100644 --- a/engine/controllers/engines.cc +++ b/engine/controllers/engines.cc @@ -129,7 +129,8 @@ void Engines::GetEngineReleases( void Engines::GetEngineVariants( const HttpRequestPtr& req, std::function&& callback, - const std::string& engine, const std::string& version) const { + const std::string& engine, const std::string& version, + std::optional show) const { if (engine.empty()) { Json::Value res; res["message"] = "Engine name is required"; @@ -140,7 +141,18 @@ void Engines::GetEngineVariants( return; } - auto result = engine_service_->GetEngineVariants(engine, version); + auto show_value = show.value_or("all"); + if (show_value != "all" && show_value != "compatible") { + Json::Value res; + res["message"] = "Invalid show value. 
Can either be `all` or `compatible`"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto result = engine_service_->GetEngineVariants(engine, version, + show_value == "compatible"); auto normalize_version = string_utils::RemoveSubstring(version, "v"); Json::Value releases(Json::arrayValue); diff --git a/engine/controllers/engines.h b/engine/controllers/engines.h index b0a92b6c3..3ad9708e3 100644 --- a/engine/controllers/engines.h +++ b/engine/controllers/engines.h @@ -53,13 +53,11 @@ class Engines : public drogon::HttpController { METHOD_ADD(Engines::GetEngineReleases, "/{1}/releases", Get); ADD_METHOD_TO(Engines::GetEngineReleases, "/v1/engines/{1}/releases", Get); - METHOD_ADD(Engines::GetEngineVariants, "/{1}/releases/{2}", Get); - ADD_METHOD_TO(Engines::GetEngineVariants, "/v1/engines/{1}/releases/{2}", - Get); + ADD_METHOD_TO(Engines::GetEngineVariants, + "/v1/engines/{engine}/releases/{version}?show={show}", Get); - METHOD_ADD(Engines::GetLatestEngineVersion, "/{1}/releases/latest", Get); ADD_METHOD_TO(Engines::GetLatestEngineVersion, - "/v1/engines/{1}/releases/latest", Get); + "/v1/engines/{engine}/releases/latest", Get); METHOD_LIST_END @@ -83,8 +81,8 @@ class Engines : public drogon::HttpController { void GetEngineVariants(const HttpRequestPtr& req, std::function&& callback, - const std::string& engine, - const std::string& version) const; + const std::string& engine, const std::string& version, + std::optional show) const; void GetInstalledEngineVariants( const HttpRequestPtr& req, diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index c8f4c180c..2ca06cb33 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -482,7 +482,8 @@ EngineService::GetEngineReleases(const std::string& engine) const { cpp::result, std::string> EngineService::GetEngineVariants(const std::string& engine, - const 
std::string& version) const { + const std::string& version, + bool filter_compatible_only) const { auto ne = NormalizeEngine(engine); auto engine_release = github_release_utils::GetReleaseByVersion("janhq", ne, version); @@ -506,6 +507,44 @@ EngineService::GetEngineVariants(const std::string& engine, return cpp::fail("No compatible variants found for " + engine); } + if (filter_compatible_only) { + auto system_info = system_info_utils::GetSystemInfo(); + compatible_variants.erase( + std::remove_if(compatible_variants.begin(), compatible_variants.end(), + [&system_info](const EngineVariant& variant) { + std::string name = variant.name; + std::transform(name.begin(), name.end(), name.begin(), + ::tolower); + + bool os_match = false; + if (system_info->os == "mac" && + name.find("mac") != std::string::npos) + os_match = true; + if (system_info->os == "windows" && + name.find("windows") != std::string::npos) + os_match = true; + if (system_info->os == "linux" && + name.find("linux") != std::string::npos) + os_match = true; + + bool arch_match = false; + if (system_info->arch == "arm64" && + name.find("arm64") != std::string::npos) + arch_match = true; + if (system_info->arch == "amd64" && + name.find("amd64") != std::string::npos) + arch_match = true; + + return !(os_match && arch_match); + }), + compatible_variants.end()); + + if (compatible_variants.empty()) { + return cpp::fail("No compatible variants found for system " + + system_info->os + "/" + system_info->arch); + } + } + return compatible_variants; } diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 527123cb5..ab58e0e4a 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -101,7 +101,8 @@ class EngineService : public EngineServiceI { const std::string& engine) const; cpp::result, std::string> GetEngineVariants( - const std::string& engine, const std::string& version) const; + const std::string& engine, const std::string& version, + bool 
filter_compatible_only = false) const; cpp::result SetDefaultEngineVariant( const std::string& engine, const std::string& version, From e408f785ac29d42aa20974eae639d4199687be6e Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 23 Dec 2024 11:46:08 +0700 Subject: [PATCH 02/16] feat: rendering chat_template (#1814) --- engine/cli/commands/chat_completion_cmd.cc | 13 +- engine/common/model_metadata.h | 29 + engine/common/tokenizer.h | 72 + engine/controllers/files.cc | 17 +- engine/controllers/server.cc | 9 +- engine/main.cc | 1 + engine/services/engine_service.h | 20 +- engine/services/inference_service.cc | 42 +- engine/services/inference_service.h | 9 +- engine/services/model_service.cc | 49 +- engine/services/model_service.h | 18 +- engine/test/components/test_gguf_parser.cc | 245 +- engine/utils/chat-template.hpp | 137 + engine/utils/cortex_utils.h | 34 +- engine/utils/gguf_metadata_reader.h | 420 +++ engine/utils/jinja_utils.h | 27 + engine/utils/minja.hpp | 3428 ++++++++++++++++++++ 17 files changed, 4402 insertions(+), 168 deletions(-) create mode 100644 engine/common/model_metadata.h create mode 100644 engine/common/tokenizer.h create mode 100644 engine/utils/chat-template.hpp create mode 100644 engine/utils/gguf_metadata_reader.h create mode 100644 engine/utils/jinja_utils.h create mode 100644 engine/utils/minja.hpp diff --git a/engine/cli/commands/chat_completion_cmd.cc b/engine/cli/commands/chat_completion_cmd.cc index 0067b1c08..77d222176 100644 --- a/engine/cli/commands/chat_completion_cmd.cc +++ b/engine/cli/commands/chat_completion_cmd.cc @@ -50,7 +50,6 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { return data_length; } - } // namespace void ChatCompletionCmd::Exec(const std::string& host, int port, @@ -103,7 +102,7 @@ void ChatCompletionCmd::Exec(const std::string& host, int port, return; } - std::string url = "http://" + address + "/v1/chat/completions"; + auto url = "http://" + address + "/v1/chat/completions"; 
curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_POST, 1L); @@ -151,9 +150,10 @@ void ChatCompletionCmd::Exec(const std::string& host, int port, json_data["model"] = model_handle; json_data["stream"] = true; - std::string json_payload = json_data.toStyledString(); - - curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str()); + auto json_str = json_data.toStyledString(); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, json_str.length()); + curl_easy_setopt(curl, CURLOPT_TCP_KEEPALIVE, 1L); std::string ai_chat; StreamingCallback callback; @@ -161,8 +161,7 @@ void ChatCompletionCmd::Exec(const std::string& host, int port, curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); curl_easy_setopt(curl, CURLOPT_WRITEDATA, &callback); - - CURLcode res = curl_easy_perform(curl); + auto res = curl_easy_perform(curl); if (res != CURLE_OK) { CLI_LOG("CURL request failed: " << curl_easy_strerror(res)); diff --git a/engine/common/model_metadata.h b/engine/common/model_metadata.h new file mode 100644 index 000000000..af733665b --- /dev/null +++ b/engine/common/model_metadata.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include "common/tokenizer.h" + +struct ModelMetadata { + uint32_t version; + uint64_t tensor_count; + uint64_t metadata_kv_count; + std::shared_ptr tokenizer; + + std::string ToString() const { + std::ostringstream ss; + ss << "ModelMetadata {\n" + << "version: " << version << "\n" + << "tensor_count: " << tensor_count << "\n" + << "metadata_kv_count: " << metadata_kv_count << "\n" + << "tokenizer: "; + + if (tokenizer) { + ss << "\n" << tokenizer->ToString(); + } else { + ss << "null"; + } + + ss << "\n}"; + return ss.str(); + } +}; diff --git a/engine/common/tokenizer.h b/engine/common/tokenizer.h new file mode 100644 index 000000000..33367f06b --- /dev/null +++ b/engine/common/tokenizer.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include + +struct 
Tokenizer { + std::string eos_token = ""; + bool add_eos_token = true; + + std::string bos_token = ""; + bool add_bos_token = true; + + std::string unknown_token = ""; + std::string padding_token = ""; + + std::string chat_template = ""; + + bool add_generation_prompt = true; + + // Helper function for common fields + std::string BaseToString() const { + std::ostringstream ss; + ss << "eos_token: \"" << eos_token << "\"\n" + << "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n" + << "bos_token: \"" << bos_token << "\"\n" + << "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n" + << "unknown_token: \"" << unknown_token << "\"\n" + << "padding_token: \"" << padding_token << "\"\n" + << "chat_template: \"" << chat_template << "\"\n" + << "add_generation_prompt: " + << (add_generation_prompt ? "true" : "false") << "\""; + return ss.str(); + } + + virtual ~Tokenizer() = default; + + virtual std::string ToString() = 0; +}; + +struct GgufTokenizer : public Tokenizer { + std::string pre = ""; + + ~GgufTokenizer() override = default; + + std::string ToString() override { + std::ostringstream ss; + ss << "GgufTokenizer {\n"; + // Add base class members + ss << BaseToString() << "\n"; + // Add derived class members + ss << "pre: \"" << pre << "\"\n"; + ss << "}"; + return ss.str(); + } +}; + +struct SafeTensorTokenizer : public Tokenizer { + bool add_prefix_space = true; + + ~SafeTensorTokenizer() = default; + + std::string ToString() override { + std::ostringstream ss; + ss << "SafeTensorTokenizer {\n"; + // Add base class members + ss << BaseToString() << "\n"; + // Add derived class members + ss << "add_prefix_space: " << (add_prefix_space ? 
"true" : "false") << "\n"; + ss << "}"; + return ss.str(); + } +}; diff --git a/engine/controllers/files.cc b/engine/controllers/files.cc index e0cd502f4..ed37967b2 100644 --- a/engine/controllers/files.cc +++ b/engine/controllers/files.cc @@ -216,10 +216,8 @@ void Files::RetrieveFileContent( return; } - auto [buffer, size] = std::move(res.value()); - auto resp = HttpResponse::newHttpResponse(); - resp->setBody(std::string(buffer.get(), size)); - resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM); + auto resp = + cortex_utils::CreateCortexContentResponse(std::move(res.value())); callback(resp); } else { if (!msg_res->rel_path.has_value()) { @@ -243,10 +241,8 @@ void Files::RetrieveFileContent( return; } - auto [buffer, size] = std::move(content_res.value()); - auto resp = HttpResponse::newHttpResponse(); - resp->setBody(std::string(buffer.get(), size)); - resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM); + auto resp = cortex_utils::CreateCortexContentResponse( + std::move(content_res.value())); callback(resp); } } @@ -261,9 +257,6 @@ void Files::RetrieveFileContent( return; } - auto [buffer, size] = std::move(res.value()); - auto resp = HttpResponse::newHttpResponse(); - resp->setBody(std::string(buffer.get(), size)); - resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM); + auto resp = cortex_utils::CreateCortexContentResponse(std::move(res.value())); callback(resp); } diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 4c6bcaf82..19842bcdb 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -3,7 +3,6 @@ #include "trantor/utils/Logger.h" #include "utils/cortex_utils.h" #include "utils/function_calling/common.h" -#include "utils/http_util.h" using namespace inferences; @@ -27,6 +26,14 @@ void server::ChatCompletion( std::function&& callback) { LOG_DEBUG << "Start chat completion"; auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Body can't be 
empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } bool is_stream = (*json_body).get("stream", false).asBool(); auto model_id = (*json_body).get("model", "invalid_model").asString(); auto engine_type = [this, &json_body]() -> std::string { diff --git a/engine/main.cc b/engine/main.cc index 5cc6c740e..ddf1eefd8 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -159,6 +159,7 @@ void RunServer(std::optional host, std::optional port, auto model_src_svc = std::make_shared(); auto model_service = std::make_shared( download_service, inference_svc, engine_service); + inference_svc->SetModelService(model_service); auto file_watcher_srv = std::make_shared( model_dir_path.string(), model_service); diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index ab58e0e4a..8ead4f6d6 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -17,7 +16,6 @@ #include "utils/cpuid/cpu_info.h" #include "utils/dylib.h" #include "utils/dylib_path_manager.h" -#include "utils/engine_constants.h" #include "utils/github_release_utils.h" #include "utils/result.hpp" #include "utils/system_info_utils.h" @@ -48,10 +46,6 @@ class EngineService : public EngineServiceI { struct EngineInfo { std::unique_ptr dl; EngineV engine; -#if defined(_WIN32) - DLL_DIRECTORY_COOKIE cookie; - DLL_DIRECTORY_COOKIE cuda_cookie; -#endif }; std::mutex engines_mutex_; @@ -106,21 +100,23 @@ class EngineService : public EngineServiceI { cpp::result SetDefaultEngineVariant( const std::string& engine, const std::string& version, - const std::string& variant); + const std::string& variant) override; cpp::result GetDefaultEngineVariant( - const std::string& engine); + const std::string& engine) override; cpp::result, std::string> - GetInstalledEngineVariants(const std::string& engine) const; + 
GetInstalledEngineVariants(const std::string& engine) const override; cpp::result GetLoadedEngine( const std::string& engine_name); std::vector GetLoadedEngines(); - cpp::result LoadEngine(const std::string& engine_name); - cpp::result UnloadEngine(const std::string& engine_name); + cpp::result LoadEngine( + const std::string& engine_name) override; + cpp::result UnloadEngine( + const std::string& engine_name) override; cpp::result GetLatestEngineVersion(const std::string& engine) const; @@ -138,7 +134,7 @@ class EngineService : public EngineServiceI { cpp::result GetEngineByNameAndVariant( const std::string& engine_name, - const std::optional variant = std::nullopt); + const std::optional variant = std::nullopt) override; cpp::result UpsertEngine( const std::string& engine_name, const std::string& type, diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 91cb277dc..08107562b 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -2,6 +2,7 @@ #include #include "utils/engine_constants.h" #include "utils/function_calling/common.h" +#include "utils/jinja_utils.h" namespace services { cpp::result InferenceService::HandleChatCompletion( @@ -24,6 +25,45 @@ cpp::result InferenceService::HandleChatCompletion( return cpp::fail(std::make_pair(stt, res)); } + { + auto model_id = json_body->get("model", "").asString(); + if (!model_id.empty()) { + if (auto model_service = model_service_.lock()) { + auto metadata_ptr = model_service->GetCachedModelMetadata(model_id); + if (metadata_ptr != nullptr && + !metadata_ptr->tokenizer->chat_template.empty()) { + auto tokenizer = metadata_ptr->tokenizer; + auto messages = (*json_body)["messages"]; + Json::Value messages_jsoncpp(Json::arrayValue); + for (auto message : messages) { + messages_jsoncpp.append(message); + } + + Json::Value tools(Json::arrayValue); + Json::Value template_data_json; + template_data_json["messages"] = messages_jsoncpp; + // 
template_data_json["tools"] = tools; + + auto prompt_result = jinja::RenderTemplate( + tokenizer->chat_template, template_data_json, + tokenizer->bos_token, tokenizer->eos_token, + tokenizer->add_bos_token, tokenizer->add_eos_token, + tokenizer->add_generation_prompt); + if (prompt_result.has_value()) { + (*json_body)["prompt"] = prompt_result.value(); + Json::Value stops(Json::arrayValue); + stops.append(tokenizer->eos_token); + (*json_body)["stop"] = stops; + } else { + CTL_ERR("Failed to render prompt: " + prompt_result.error()); + } + } + } + } + } + + CTL_INF("Json body inference: " + json_body->toStyledString()); + auto cb = [q, tool_choice](Json::Value status, Json::Value res) { if (!tool_choice.isNull()) { res["tool_choice"] = tool_choice; @@ -297,4 +337,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr json_body, } return true; } -} // namespace services \ No newline at end of file +} // namespace services diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index b417fa14a..54bc9dc29 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -4,9 +4,11 @@ #include #include #include "services/engine_service.h" +#include "services/model_service.h" #include "utils/result.hpp" -#include "extensions/remote-engine/remote_engine.h" + namespace services { + // Status and result using InferResult = std::pair; @@ -58,7 +60,12 @@ class InferenceService { bool HasFieldInReq(std::shared_ptr json_body, const std::string& field); + void SetModelService(std::shared_ptr model_service) { + model_service_ = model_service; + } + private: std::shared_ptr engine_service_; + std::weak_ptr model_service_; }; } // namespace services diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index ce83152c4..0d909b61f 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -9,10 +9,11 @@ #include "config/yaml_config.h" #include "database/models.h" 
#include "hardware_service.h" +#include "services/inference_service.h" #include "utils/cli_selection_utils.h" -#include "utils/cortex_utils.h" #include "utils/engine_constants.h" #include "utils/file_manager_utils.h" +#include "utils/gguf_metadata_reader.h" #include "utils/huggingface_utils.h" #include "utils/logging_utils.h" #include "utils/result.hpp" @@ -877,6 +878,18 @@ cpp::result ModelService::StartModel( auto data = std::get<1>(ir); if (status == drogon::k200OK) { + // start model successfully, we store the metadata so we can use + // for each inference + auto metadata_res = GetModelMetadata(model_handle); + if (metadata_res.has_value()) { + loaded_model_metadata_map_.emplace(model_handle, + std::move(metadata_res.value())); + CTL_INF("Successfully stored metadata for model " << model_handle); + } else { + CTL_WRN("Failed to get metadata for model " << model_handle << ": " + << metadata_res.error()); + } + return StartModelResult{.success = true, .warning = may_fallback_res.value()}; } else if (status == drogon::k409Conflict) { @@ -929,6 +942,8 @@ cpp::result ModelService::StopModel( if (bypass_check) { bypass_stop_check_set_.erase(model_handle); } + loaded_model_metadata_map_.erase(model_handle); + CTL_INF("Removed metadata for model " << model_handle); return true; } else { CTL_ERR("Model failed to stop with status code: " << status); @@ -1208,3 +1223,35 @@ ModelService::MayFallbackToCpu(const std::string& model_path, int ngl, return warning; } + +cpp::result, std::string> +ModelService::GetModelMetadata(const std::string& model_id) const { + if (model_id.empty()) { + return cpp::fail("Model ID can't be empty"); + } + + auto model_config = GetDownloadedModel(model_id); + if (!model_config.has_value()) { + return cpp::fail("Can't get model config for " + model_id); + } + + if (model_config->files.empty()) { + return cpp::fail("Model has no actual file. 
Might not be a local model!"); + } + // TODO: handle the case we have multiple files + auto file = model_config->files[0]; + + auto model_metadata_res = cortex_utils::ReadGgufMetadata( + file_manager_utils::ToAbsoluteCortexDataPath( + std::filesystem::path(file))); + if (!model_metadata_res.has_value()) { + CTL_ERR("Failed to read metadata: " + model_metadata_res.error()); + return cpp::fail("Failed to read metadata: " + model_metadata_res.error()); + } + return std::move(*model_metadata_res); +} + +std::shared_ptr ModelService::GetCachedModelMetadata( + const std::string& model_id) const { + return loaded_model_metadata_map_.at(model_id); +} diff --git a/engine/services/model_service.h b/engine/services/model_service.h index e2638fd1f..8b24b3421 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -4,11 +4,15 @@ #include #include #include "common/engine_servicei.h" +#include "common/model_metadata.h" #include "config/model_config.h" #include "services/download_service.h" -#include "services/inference_service.h" #include "utils/hardware/gguf/gguf_file_estimate.h" +namespace services { +class InferenceService; +} + struct ModelPullInfo { std::string id; std::string default_branch; @@ -101,6 +105,12 @@ class ModelService { const std::string& model_handle, const std::string& kv_cache = "f16", int n_batch = 2048, int n_ubatch = 2048); + cpp::result, std::string> GetModelMetadata( + const std::string& model_id) const; + + std::shared_ptr GetCachedModelMetadata( + const std::string& model_id) const; + private: /** * Handle downloading model which have following pattern: author/model_name @@ -124,4 +134,10 @@ class ModelService { std::shared_ptr inference_svc_; std::unordered_set bypass_stop_check_set_; std::shared_ptr engine_svc_ = nullptr; + + /** + * Store the chat template of loaded model. 
+ */ + std::unordered_map> + loaded_model_metadata_map_; }; diff --git a/engine/test/components/test_gguf_parser.cc b/engine/test/components/test_gguf_parser.cc index 6c5c61486..df860a1d7 100644 --- a/engine/test/components/test_gguf_parser.cc +++ b/engine/test/components/test_gguf_parser.cc @@ -1,12 +1,10 @@ -#include "gtest/gtest.h" -#include "config/gguf_parser.h" -#include "config/yaml_config.h" -#include #include -#include -#include +#include #include -#include +#include +#include "config/gguf_parser.h" +#include "config/yaml_config.h" +#include "gtest/gtest.h" #ifdef _WIN32 #include @@ -15,144 +13,145 @@ #endif class GGUFParserTest : public ::testing::Test { -protected: - void SetUp() override { - gguf_handler = std::make_unique(); - yaml_handler = std::make_unique< config::YamlHandler>(); - } + protected: + void SetUp() override { + gguf_handler = std::make_unique(); + yaml_handler = std::make_unique(); + } - void TearDown() override { - } + void TearDown() override {} - std::unique_ptr gguf_handler; - std::unique_ptr yaml_handler; - - std::string getTempFilePath(const std::string& prefix, const std::string& extension) { - #ifdef _WIN32 - char temp_path[MAX_PATH]; - char file_name[MAX_PATH]; - GetTempPathA(MAX_PATH, temp_path); - GetTempFileNameA(temp_path, prefix.c_str(), 0, file_name); - std::string path(file_name); - DeleteFileA(file_name); // Delete the file created by GetTempFileNameA - return path + extension; - #else - std::string path = "/tmp/" + prefix + "XXXXXX" + extension; - char* temp = strdup(path.c_str()); - int fd = mkstemps(temp, extension.length()); - if (fd == -1) { - free(temp); - throw std::runtime_error("Failed to create temporary file"); - } - close(fd); - std::string result(temp); - free(temp); - return result; - #endif + std::unique_ptr gguf_handler; + std::unique_ptr yaml_handler; + + std::string getTempFilePath(const std::string& prefix, + const std::string& extension) { +#ifdef _WIN32 + char temp_path[MAX_PATH]; + char 
file_name[MAX_PATH]; + GetTempPathA(MAX_PATH, temp_path); + GetTempFileNameA(temp_path, prefix.c_str(), 0, file_name); + std::string path(file_name); + DeleteFileA(file_name); // Delete the file created by GetTempFileNameA + return path + extension; +#else + std::string path = "/tmp/" + prefix + "XXXXXX" + extension; + char* temp = strdup(path.c_str()); + int fd = mkstemps(temp, extension.length()); + if (fd == -1) { + free(temp); + throw std::runtime_error("Failed to create temporary file"); } + close(fd); + std::string result(temp); + free(temp); + return result; +#endif + } - std::string createMockGGUFFile() { - std::string gguf_path = getTempFilePath("mock_tinyllama-model", ".gguf"); - std::ofstream file(gguf_path, std::ios::binary); + std::string createMockGGUFFile() { + std::string gguf_path = getTempFilePath("mock_tinyllama-model", ".gguf"); + std::ofstream file(gguf_path, std::ios::binary); - if (!file.is_open()) { - throw std::runtime_error("Failed to create mock GGUF file"); - } + if (!file.is_open()) { + throw std::runtime_error("Failed to create mock GGUF file"); + } - try { - // GGUF magic number - uint32_t magic = 0x46554747; - file.write(reinterpret_cast(&magic), sizeof(magic)); - - // Version - uint32_t version = 2; - file.write(reinterpret_cast(&version), sizeof(version)); - - // Tensor count (not important for this test) - uint64_t tensor_count = 0; - file.write(reinterpret_cast(&tensor_count), sizeof(tensor_count)); - - // Metadata key-value count - uint64_t kv_count = 2; - file.write(reinterpret_cast(&kv_count), sizeof(kv_count)); - - // Helper function to write a string - auto writeString = [&file](const std::string& str) { - uint64_t length = str.length(); - file.write(reinterpret_cast(&length), sizeof(length)); - file.write(str.c_str(), length); - }; - - // Helper function to write a key-value pair - auto writeKV = [&](const std::string& key, uint32_t type, const auto& value) { - writeString(key); - file.write(reinterpret_cast(&type), 
sizeof(type)); - if constexpr (std::is_same_v) { - writeString(value); - } else { - file.write(reinterpret_cast(&value), sizeof(value)); - } - }; - - // Write metadata - writeKV("general.name", 8, std::string("tinyllama 1B")); - writeKV("llama.context_length", 4, uint32_t(4096)); - - file.close(); - - } catch (const std::exception& e) { - file.close(); - std::remove(gguf_path.c_str()); - throw std::runtime_error(std::string("Failed to write mock GGUF file: ") + e.what()); + try { + // GGUF magic number + uint32_t magic = 0x46554747; + file.write(reinterpret_cast(&magic), sizeof(magic)); + + // Version + uint32_t version = 2; + file.write(reinterpret_cast(&version), sizeof(version)); + + // Tensor count (not important for this test) + uint64_t tensor_count = 0; + file.write(reinterpret_cast(&tensor_count), sizeof(tensor_count)); + + // Metadata key-value count + uint64_t kv_count = 2; + file.write(reinterpret_cast(&kv_count), sizeof(kv_count)); + + // Helper function to write a string + auto writeString = [&file](const std::string& str) { + uint64_t length = str.length(); + file.write(reinterpret_cast(&length), sizeof(length)); + file.write(str.c_str(), length); + }; + + // Helper function to write a key-value pair + auto writeKV = [&](const std::string& key, uint32_t type, + const auto& value) { + writeString(key); + file.write(reinterpret_cast(&type), sizeof(type)); + if constexpr (std::is_same_v) { + writeString(value); + } else { + file.write(reinterpret_cast(&value), sizeof(value)); } + }; - return gguf_path; + // Write metadata + writeKV("general.name", 8, std::string("tinyllama 1B")); + writeKV("llama.context_length", 4, uint32_t(4096)); + + file.close(); + + } catch (const std::exception& e) { + file.close(); + std::remove(gguf_path.c_str()); + throw std::runtime_error(std::string("Failed to write mock GGUF file: ") + + e.what()); } + + return gguf_path; + } }; TEST_F(GGUFParserTest, ParseMockTinyLlamaModel) { - std::string gguf_path; - std::string 
yaml_path; - try { - // Create a mock GGUF file - gguf_path = createMockGGUFFile(); + std::string gguf_path; + std::string yaml_path; + try { + // Create a mock GGUF file + gguf_path = createMockGGUFFile(); - // Parse the GGUF file - gguf_handler->Parse(gguf_path); + // Parse the GGUF file + gguf_handler->Parse(gguf_path); - const config::ModelConfig& gguf_config = gguf_handler->GetModelConfig(); + const config::ModelConfig& gguf_config = gguf_handler->GetModelConfig(); - // Load the expected configuration from YAML - std::string yaml_content = R"( + // Load the expected configuration from YAML + std::string yaml_content = R"( name: tinyllama-1B ctx_len: 4096 )"; - yaml_path = getTempFilePath("expected_config", ".yaml"); - std::ofstream yaml_file(yaml_path); - yaml_file << yaml_content; - yaml_file.close(); + yaml_path = getTempFilePath("expected_config", ".yaml"); + std::ofstream yaml_file(yaml_path); + yaml_file << yaml_content; + yaml_file.close(); - yaml_handler->ModelConfigFromFile(yaml_path); + yaml_handler->ModelConfigFromFile(yaml_path); - const config::ModelConfig& yaml_config = yaml_handler->GetModelConfig(); + const config::ModelConfig& yaml_config = yaml_handler->GetModelConfig(); - // Compare GGUF parsed config with YAML config - EXPECT_EQ(gguf_config.name, yaml_config.name); - EXPECT_EQ(gguf_config.ctx_len, yaml_config.ctx_len); + // Compare GGUF parsed config with YAML config + EXPECT_EQ(gguf_config.name, yaml_config.name); + EXPECT_EQ(gguf_config.ctx_len, yaml_config.ctx_len); - // Clean up - std::remove(gguf_path.c_str()); - std::remove(yaml_path.c_str()); + // Clean up + std::remove(gguf_path.c_str()); + std::remove(yaml_path.c_str()); + } catch (const std::exception& e) { + // If an exception was thrown, make sure to clean up the files + if (!gguf_path.empty()) { + std::remove(gguf_path.c_str()); } - catch (const std::exception& e) { - // If an exception was thrown, make sure to clean up the files - if (!gguf_path.empty()) { - 
std::remove(gguf_path.c_str()); - } - if (!yaml_path.empty()) { - std::remove(yaml_path.c_str()); - } - FAIL() << "Exception thrown: " << e.what(); + if (!yaml_path.empty()) { + std::remove(yaml_path.c_str()); } -} \ No newline at end of file + FAIL() << "Exception thrown: " << e.what(); + } +} diff --git a/engine/utils/chat-template.hpp b/engine/utils/chat-template.hpp new file mode 100644 index 000000000..309dd8e97 --- /dev/null +++ b/engine/utils/chat-template.hpp @@ -0,0 +1,137 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include +#include "utils/minja.hpp" + +using json = nlohmann::ordered_json; + +namespace minja { + +class chat_template { + public: + private: + bool _supports_tools = true; + // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. + // Most other templates (and OpenAI's API) expect the arguments object to be stringified. 
+ bool _requires_object_arguments = false; + bool _supports_system_role = true; + std::string _source; + std::string _bos_token; + std::string _eos_token; + std::shared_ptr _template_root; + + public: + chat_template(const std::string& source, const std::string& bos_token, + const std::string& eos_token) + : _source(source), _bos_token(bos_token), _eos_token(eos_token) { + _supports_tools = source.find("tools") != std::string::npos; + _requires_object_arguments = + source.find("tool_call.arguments | items") != std::string::npos || + source.find("tool_call.arguments | tojson") != std::string::npos; + _supports_system_role = + source.find("System role not supported") == std::string::npos; + + _template_root = + minja::Parser::parse(_source, { + /* .trim_blocks = */ true, + /* .lstrip_blocks = */ true, + /* .keep_trailing_newline = */ false, + }); + } + + const std::string& source() const { return _source; } + bool supports_tools() const { return _supports_tools; } + + std::string apply(const nlohmann::ordered_json& messages, + const nlohmann::ordered_json& tools, + bool add_generation_prompt, + const nlohmann::ordered_json& extra_context = + nlohmann::ordered_json()) const { + auto actual_messages = messages; + + // First, "fix" messages so they have a chance to be rendered correctly by the template + + if (_requires_object_arguments || !_supports_system_role) { + std::string pending_system; + auto flush_sys = [&]() { + if (!pending_system.empty()) { + actual_messages.push_back({ + {"role", "user"}, + {"content", pending_system}, + }); + pending_system.clear(); + } + }; + for (auto& message : actual_messages) { + if (!message.contains("role") || !message.contains("content")) { + throw std::runtime_error( + "message must have 'role' and 'content' fields: " + + message.dump()); + } + std::string role = message.at("role"); + + if (!message["content"].is_null() && !_supports_system_role) { + std::string content = message.at("content"); + if (role == "system") { + if 
(!pending_system.empty()) + pending_system += "\n"; + pending_system += content; + continue; + } else { + if (role == "user") { + if (!pending_system.empty()) { + message["content"] = + pending_system + (content.empty() ? "" : "\n" + content); + pending_system.clear(); + } + } else { + flush_sys(); + } + } + } + if (_requires_object_arguments && message.contains("tool_calls")) { + for (auto& tool_call : message.at("tool_calls")) { + if (tool_call["type"] == "function") { + auto& function = tool_call.at("function"); + std::string arguments = function.at("arguments"); + function["arguments"] = json::parse(arguments); + } + } + } + } + flush_sys(); + } + + auto context = minja::Context::make(json({ + {"messages", actual_messages}, + {"add_generation_prompt", add_generation_prompt}, + {"bos_token", _bos_token}, + {"eos_token", _eos_token}, + })); + + if (!tools.is_null()) { + auto tools_val = minja::Value(tools); + context->set("tools", tools_val); + } + if (!extra_context.is_null()) { + for (auto& kv : extra_context.items()) { + minja::Value val(kv.value()); + context->set(kv.key(), val); + } + } + + return _template_root->render(context); + } +}; + +} // namespace minja diff --git a/engine/utils/cortex_utils.h b/engine/utils/cortex_utils.h index 4d0a956a9..f58fcfe8f 100644 --- a/engine/utils/cortex_utils.h +++ b/engine/utils/cortex_utils.h @@ -2,16 +2,10 @@ #include #include #include -#include #include -#include -#include #include -#include -#include -#include -#include #include +#include #if defined(__linux__) #include #include @@ -69,6 +63,30 @@ inline drogon::HttpResponsePtr CreateCortexHttpJsonResponse( return res; }; +inline drogon::HttpResponsePtr CreateCortexContentResponse( + std::pair, size_t> content) { + auto [buffer, size] = std::move(content); + auto resp = drogon::HttpResponse::newHttpResponse(); + resp->setBody(std::string(buffer.get(), size)); + resp->setContentTypeCode(drogon::CT_APPLICATION_OCTET_STREAM); + +#if defined(_WIN32) + 
resp->addHeader("date", GetDateRFC1123()); +#endif + return resp; +} + +inline drogon::HttpResponsePtr CreateTextPlainResponse( + const std::string& text) { + auto resp = drogon::HttpResponse::newHttpResponse(); + resp->setBody(text); + resp->setContentTypeCode(drogon::CT_TEXT_PLAIN); +#if defined(_WIN32) + resp->addHeader("date", GetDateRFC1123()); +#endif + return resp; +} + inline drogon::HttpResponsePtr CreateCortexStreamResponse( const std::function& callback, const std::string& attachmentFileName = "") { @@ -80,8 +98,6 @@ inline drogon::HttpResponsePtr CreateCortexStreamResponse( return res; } - - #if defined(_WIN32) inline std::string GetCurrentPath() { char path[MAX_PATH]; diff --git a/engine/utils/gguf_metadata_reader.h b/engine/utils/gguf_metadata_reader.h new file mode 100644 index 000000000..838177505 --- /dev/null +++ b/engine/utils/gguf_metadata_reader.h @@ -0,0 +1,420 @@ +#pragma once + +#include +#include +#include +#include +#include "common/model_metadata.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +/** + * Parsing the GGUF metadata. 
+ * + * Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md + */ +namespace cortex_utils { +namespace { +// present in the first 4 bytes of a GGUF file +constexpr uint32_t GGUF_MAGIC_NUMBER = 1179993927; + +constexpr static auto GGUF_VERSION_LENGTH = 4; +constexpr static auto TENSOR_COUNT_LENGTH = 8; +constexpr static auto METADATA_KV_COUNT = 8; + +constexpr static auto TOKEN_LIST_KEY = "tokenizer.ggml.tokens"; +constexpr static auto BOS_ID_KEY = "tokenizer.ggml.bos_token_id"; +constexpr static auto EOS_ID_KEY = "tokenizer.ggml.eos_token_id"; +constexpr static auto UNK_ID_KEY = "tokenizer.ggml.unknown_token_id"; +constexpr static auto PADDING_ID_KEY = "tokenizer.ggml.padding_token_id"; + +constexpr static auto CHAT_TEMPLATE_ID_KEY = "tokenizer.chat_template"; +constexpr static auto ADD_BOS_TOKEN_KEY = "tokenizer.ggml.add_bos_token"; +constexpr static auto ADD_EOS_TOKEN_KEY = "tokenizer.ggml.add_eos_token"; +const std::vector kSpecialTokenIds{BOS_ID_KEY, EOS_ID_KEY, + UNK_ID_KEY, PADDING_ID_KEY}; + +struct MetadataArrayElement; + +// clang-format off +using MetadataValue = std::variant< + uint8_t, int8_t, + uint16_t, int16_t, + uint32_t, int32_t, + uint64_t, int64_t, + float, double, + bool, std::string, + std::vector +>; + +// clang-format on + +struct MetadataArrayElement { + MetadataValue value; + + // Add constructors for different types + MetadataArrayElement(uint8_t v) : value(v) {} + MetadataArrayElement(int8_t v) : value(v) {} + MetadataArrayElement(uint16_t v) : value(v) {} + MetadataArrayElement(int16_t v) : value(v) {} + MetadataArrayElement(uint32_t v) : value(v) {} + MetadataArrayElement(int32_t v) : value(v) {} + MetadataArrayElement(uint64_t v) : value(v) {} + MetadataArrayElement(int64_t v) : value(v) {} + MetadataArrayElement(float v) : value(v) {} + MetadataArrayElement(double v) : value(v) {} + MetadataArrayElement(bool v) : value(v) {} + MetadataArrayElement(const std::string& v) : value(v) {} + 
MetadataArrayElement(std::string&& v) : value(std::move(v)) {} + + MetadataArrayElement(MetadataValue&& v) : value(std::move(v)) {} +}; + +struct MetadataValueResult { + size_t bytes_read; + MetadataValue value; + + template + MetadataValueResult(size_t br, T&& val) + : bytes_read(br), value(std::forward(val)) {} +}; + +std::pair ReadString(std::ifstream& file) { + uint64_t length; + file.read(reinterpret_cast(&length), sizeof(uint64_t)); + + if (!file) { + throw std::runtime_error("Failed to read string length"); + } + + if (length > 1024 * 1024 * 1024) { + throw std::runtime_error("String length too large: " + + std::to_string(length)); + } + + std::string value(length, '\0'); + file.read(value.data(), length); + + if (!file) { + throw std::runtime_error("Failed to read string content of length " + + std::to_string(length)); + } + + return {8 + length, value}; +} + +inline MetadataValueResult ReadMetadataValue(uint32_t value_type, + std::ifstream& file, + const std::string& key) { + switch (value_type) { + case 0: { // uint8 + uint8_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(uint8_t), value}; + } + case 1: { // int8 + int8_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(int8_t), value}; + } + case 2: { // uint16 + uint16_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(uint16_t), value}; + } + case 3: { // int16 + int16_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(int16_t), value}; + } + case 4: { // uint32 + uint32_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(uint32_t), value}; + } + case 5: { // int32 + int32_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(int32_t), value}; + } + case 6: { // float32 + float value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(float), value}; + } + case 7: { // bool + bool value; + 
file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(bool), value}; + } + case 8: { // string + auto [length, value] = ReadString(file); + return {length, value}; + } + case 9: { // array + uint32_t array_type; + file.read(reinterpret_cast(&array_type), sizeof(uint32_t)); + + uint64_t array_length; + file.read(reinterpret_cast(&array_length), sizeof(uint64_t)); + + size_t bytes_read = 12; // 4 for type + 8 for length + + std::vector array_values_string; + std::vector array_values_float; + + for (uint64_t i = 0; i < array_length; ++i) { + auto result = ReadMetadataValue(array_type, file, + key + "[" + std::to_string(i) + "]"); + bytes_read += result.bytes_read; + + if (array_type == 8) { + array_values_string.push_back(std::get(result.value)); + } else { + float float_value; + switch (result.value.index()) { + case 0: + float_value = static_cast(std::get(result.value)); + break; + case 1: + float_value = static_cast(std::get(result.value)); + break; + case 2: + float_value = + static_cast(std::get(result.value)); + break; + case 3: + float_value = static_cast(std::get(result.value)); + break; + case 4: + float_value = + static_cast(std::get(result.value)); + break; + case 5: + float_value = static_cast(std::get(result.value)); + break; + case 6: + float_value = + static_cast(std::get(result.value)); + break; + case 7: + float_value = static_cast(std::get(result.value)); + break; + case 8: + float_value = std::get(result.value); + break; + case 9: + float_value = static_cast(std::get(result.value)); + break; + case 10: + float_value = static_cast(std::get(result.value)); + break; + default: + throw std::runtime_error( + "Unexpected type in array element conversion"); + } + array_values_float.push_back(float_value); + } + } + + if (!array_values_string.empty()) { + std::vector result; + result.reserve(array_values_string.size()); + for (const auto& str : array_values_string) { + result.emplace_back(str); + } + return {bytes_read, std::move(result)}; + 
} else { + std::vector result; + result.reserve(array_values_float.size()); + for (float val : array_values_float) { + result.emplace_back(val); + } + return {bytes_read, std::move(result)}; + } + } + + case 10: { // uint64 + uint64_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(uint64_t), value}; + } + case 11: { // int64 + int64_t value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(int64_t), value}; + } + case 12: { // float64/double + double value; + file.read(reinterpret_cast(&value), sizeof(value)); + return {sizeof(double), value}; + } + default: + throw std::runtime_error("Unknown value type: " + + std::to_string(value_type) + " for key: " + key); + } +} + +void PrintMetadataValue(const std::string& key, const MetadataValue& value) { + std::ostringstream oss; + oss << "Key: " << key << " = "; + + switch (value.index()) { + case 0: // uint8_t + oss << "uint8: " << static_cast(std::get(value)); + break; + case 1: // int8_t + oss << "int8: " << static_cast(std::get(value)); + break; + case 2: // uint16_t + oss << "uint16: " << std::get(value); + break; + case 3: // int16_t + oss << "int16: " << std::get(value); + break; + case 4: // uint32_t + oss << "uint32: " << std::get(value); + break; + case 5: // int32_t + oss << "int32: " << std::get(value); + break; + case 6: // uint64_t + oss << "uint64: " << std::get(value); + break; + case 7: // int64_t + oss << "int64: " << std::get(value); + break; + case 8: // float + oss << "float: " << std::get(value); + break; + case 9: // double + oss << "double: " << std::get(value); + break; + case 10: // bool + oss << "bool: " << (std::get(value) ? 
"true" : "false"); + break; + case 11: // string + oss << "string: " << std::get(value); + break; + case 12: { // vector + const auto& arr = std::get>(value); + oss << "array[" << arr.size() << "]: {"; + for (size_t i = 0; i < arr.size(); ++i) { + if (i > 0) + oss << ", "; + std::ostringstream key_oss; + key_oss << key << "[" << i << "]"; + PrintMetadataValue(key_oss.str(), arr[i].value); + } + oss << "}"; + break; + } + } + + CTL_INF(oss.str()); +} +} // namespace + +inline cpp::result, std::string> +ReadGgufMetadata(const std::filesystem::path& path) { + if (!std::filesystem::exists(path)) { + return cpp::fail("Gguf file does not exist at " + path.string()); + } + + std::ifstream file(path, std::ios::binary); + if (!file) { + return cpp::fail("Failed to open file: " + path.string()); + } + + uint32_t magic_number; + file.read(reinterpret_cast(&magic_number), sizeof(magic_number)); + if (magic_number != GGUF_MAGIC_NUMBER) { + return cpp::fail("Invalid GGUF file: incorrect magic number"); + } + + auto metadata_ptr = std::make_shared(); + + uint32_t version; + file.read(reinterpret_cast(&version), GGUF_VERSION_LENGTH); + metadata_ptr->version = version; + + uint64_t tensor_count; + file.read(reinterpret_cast(&tensor_count), TENSOR_COUNT_LENGTH); + metadata_ptr->tensor_count = tensor_count; + + uint64_t metadata_kv_count; + file.read(reinterpret_cast(&metadata_kv_count), METADATA_KV_COUNT); + metadata_ptr->metadata_kv_count = metadata_kv_count; + + std::unordered_map kv; + for (uint64_t i = 0; i < metadata_kv_count; ++i) { + auto [key_byte_length, key] = ReadString(file); + + char value_type_bytes[4]; + file.read(value_type_bytes, 4); + uint32_t value_type = + static_cast(static_cast(value_type_bytes[0])) | + (static_cast(static_cast(value_type_bytes[1])) + << 8) | + (static_cast(static_cast(value_type_bytes[2])) + << 16) | + (static_cast(static_cast(value_type_bytes[3])) + << 24); + + try { + auto result = ReadMetadataValue(value_type, file, key); + kv.emplace(key, 
result); + } catch (const std::exception& e) { + CTL_ERR("Error reading metadata value for key '" + key + + "': " + e.what()); + return cpp::fail("Error reading metadata value for key '" + key + "'"); + } + } + + { + metadata_ptr->tokenizer = std::make_shared(); + // initialize tokenizer + if (auto it = kv.find(CHAT_TEMPLATE_ID_KEY); it != kv.end()) { + metadata_ptr->tokenizer->chat_template = + std::get(it->second.value); + } + + for (const auto& key : kSpecialTokenIds) { + if (auto it = kv.find(key); it != kv.end()) { + auto id = std::get(it->second.value); + if (auto token_it = kv.find(TOKEN_LIST_KEY); token_it != kv.end()) { + auto& tokens = std::get>( + token_it->second.value); + + if (key == BOS_ID_KEY) { + metadata_ptr->tokenizer->bos_token = + std::get(tokens[id].value); + } else if (key == EOS_ID_KEY) { + metadata_ptr->tokenizer->eos_token = + std::get(tokens[id].value); + } else if (key == UNK_ID_KEY) { + metadata_ptr->tokenizer->unknown_token = + std::get(tokens[id].value); + } else if (key == PADDING_ID_KEY) { + metadata_ptr->tokenizer->padding_token = + std::get(tokens[id].value); + } else { + CTL_ERR("Unknown special token key: " + key); + } + } + } + } + + if (auto it = kv.find(ADD_BOS_TOKEN_KEY); it != kv.end()) { + metadata_ptr->tokenizer->add_bos_token = std::get(it->second.value); + } + + if (auto it = kv.find(ADD_EOS_TOKEN_KEY); it != kv.end()) { + metadata_ptr->tokenizer->add_eos_token = std::get(it->second.value); + } + } + + CTL_INF("Parsed GGUF metadata successfully: " + metadata_ptr->ToString()); + return metadata_ptr; +} +} // namespace cortex_utils diff --git a/engine/utils/jinja_utils.h b/engine/utils/jinja_utils.h new file mode 100644 index 000000000..f614f4745 --- /dev/null +++ b/engine/utils/jinja_utils.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +#include "extensions/remote-engine/template_renderer.h" +#include "utils/chat-template.hpp" +#include "utils/result.hpp" + +namespace jinja { +inline cpp::result RenderTemplate( 
+ std::string& tmpl, const Json::Value& data, const std::string& bos_token, + const std::string& eos_token, bool add_bos_token, bool add_eos_token, + bool add_generation_prompt = true) { + try { + auto converted_json = + remote_engine::TemplateRenderer().ConvertJsonValue(data); + + minja::chat_template chat_tmpl(tmpl, add_bos_token ? bos_token : "", + add_eos_token ? eos_token : ""); + return chat_tmpl.apply(converted_json["messages"], {}, + add_generation_prompt); + } catch (const std::exception& e) { + return cpp::fail("Failed to render template: " + std::string(e.what())); + } +} +} // namespace jinja diff --git a/engine/utils/minja.hpp b/engine/utils/minja.hpp new file mode 100644 index 000000000..76f2110f2 --- /dev/null +++ b/engine/utils/minja.hpp @@ -0,0 +1,3428 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +namespace minja { + +class Context; + +struct Options { + bool trim_blocks; // removes the first newline after a block + bool lstrip_blocks; // removes leading whitespace on the line of the block + bool keep_trailing_newline; // don't remove last newline +}; + +struct ArgumentsValue; + +/* Values that behave roughly like in Python. 
*/ +class Value : public std::enable_shared_from_this { + public: + using CallableType = + std::function&, ArgumentsValue&)>; + using FilterType = + std::function&, ArgumentsValue&)>; + + private: + using ObjectType = + nlohmann::ordered_map; // Only contains primitive keys + using ArrayType = std::vector; + + std::shared_ptr array_; + std::shared_ptr object_; + std::shared_ptr callable_; + json primitive_; + + Value(const std::shared_ptr& array) : array_(array) {} + Value(const std::shared_ptr& object) : object_(object) {} + Value(const std::shared_ptr& callable) + : object_(std::make_shared()), callable_(callable) {} + + /* Python-style string repr */ + static void dump_string(const json& primitive, std::ostringstream& out, + char string_quote = '\'') { + if (!primitive.is_string()) + throw std::runtime_error("Value is not a string: " + primitive.dump()); + auto s = primitive.dump(); + if (string_quote == '"' || s.find('\'') != std::string::npos) { + out << s; + return; + } + // Reuse json dump, just changing string quotes + out << string_quote; + for (size_t i = 1, n = s.size() - 1; i < n; ++i) { + if (s[i] == '\\' && s[i + 1] == '"') { + out << '"'; + i++; + } else if (s[i] == string_quote) { + out << '\\' << string_quote; + } else { + out << s[i]; + } + } + out << string_quote; + } + void dump(std::ostringstream& out, int indent = -1, int level = 0, + bool to_json = false) const { + auto print_indent = [&](int level) { + if (indent > 0) { + out << "\n"; + for (int i = 0, n = level * indent; i < n; ++i) + out << ' '; + } + }; + auto print_sub_sep = [&]() { + out << ','; + if (indent < 0) + out << ' '; + else + print_indent(level + 1); + }; + + auto string_quote = to_json ? 
'"' : '\''; + + if (is_null()) + out << "null"; + else if (array_) { + out << "["; + print_indent(level + 1); + for (size_t i = 0; i < array_->size(); ++i) { + if (i) + print_sub_sep(); + (*array_)[i].dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "]"; + } else if (object_) { + out << "{"; + print_indent(level + 1); + for (auto begin = object_->begin(), it = begin; it != object_->end(); + ++it) { + if (it != begin) + print_sub_sep(); + if (it->first.is_string()) { + dump_string(it->first, out, string_quote); + } else { + out << string_quote << it->first.dump() << string_quote; + } + out << ": "; + it->second.dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "}"; + } else if (callable_) { + throw std::runtime_error("Cannot dump callable to JSON"); + } else if (is_boolean() && !to_json) { + out << (this->to_bool() ? "True" : "False"); + } else if (is_string() && !to_json) { + dump_string(primitive_, out, string_quote); + } else { + out << primitive_.dump(); + } + } + + public: + Value() {} + Value(const bool& v) : primitive_(v) {} + Value(const int64_t& v) : primitive_(v) {} + Value(const double& v) : primitive_(v) {} + Value(const std::nullptr_t&) {} + Value(const std::string& v) : primitive_(v) {} + Value(const char* v) : primitive_(std::string(v)) {} + + Value(const json& v) { + if (v.is_object()) { + auto object = std::make_shared(); + for (auto it = v.begin(); it != v.end(); ++it) { + (*object)[it.key()] = it.value(); + } + object_ = std::move(object); + } else if (v.is_array()) { + auto array = std::make_shared(); + for (const auto& item : v) { + array->push_back(Value(item)); + } + array_ = array; + } else { + primitive_ = v; + } + } + + std::vector keys() { + if (!object_) + throw std::runtime_error("Value is not an object: " + dump()); + std::vector res; + for (const auto& item : *object_) { + res.push_back(item.first); + } + return res; + } + + size_t size() const { + if (is_object()) + return 
object_->size(); + if (is_array()) + return array_->size(); + if (is_string()) + return primitive_.get().length(); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + static Value array(const std::vector values = {}) { + auto array = std::make_shared(); + for (const auto& item : values) { + array->push_back(item); + } + return Value(array); + } + static Value object(const std::shared_ptr object = + std::make_shared()) { + return Value(object); + } + static Value callable(const CallableType& callable) { + return Value(std::make_shared(callable)); + } + + void insert(size_t index, const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->insert(array_->begin() + index, v); + } + void push_back(const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->push_back(v); + } + Value get(const Value& key) { + if (array_) { + if (!key.is_number_integer()) { + return Value(); + } + auto index = key.get(); + return array_->at(index < 0 ? 
array_->size() + index : index); + } else if (object_) { + if (!key.is_hashable()) + throw std::runtime_error("Unashable type: " + dump()); + auto it = object_->find(key.primitive_); + if (it == object_->end()) + return Value(); + return it->second; + } + return Value(); + } + void set(const Value& key, const Value& value) { + if (!object_) + throw std::runtime_error("Value is not an object: " + dump()); + if (!key.is_hashable()) + throw std::runtime_error("Unashable type: " + dump()); + (*object_)[key.primitive_] = value; + } + Value call(const std::shared_ptr& context, + ArgumentsValue& args) const { + if (!callable_) + throw std::runtime_error("Value is not callable: " + dump()); + return (*callable_)(context, args); + } + + bool is_object() const { return !!object_; } + bool is_array() const { return !!array_; } + bool is_callable() const { return !!callable_; } + bool is_null() const { + return !object_ && !array_ && primitive_.is_null() && !callable_; + } + bool is_boolean() const { return primitive_.is_boolean(); } + bool is_number_integer() const { return primitive_.is_number_integer(); } + bool is_number_float() const { return primitive_.is_number_float(); } + bool is_number() const { return primitive_.is_number(); } + bool is_string() const { return primitive_.is_string(); } + bool is_iterable() const { return is_array() || is_object() || is_string(); } + + bool is_primitive() const { return !array_ && !object_ && !callable_; } + bool is_hashable() const { return is_primitive(); } + + bool empty() const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_string()) + return primitive_.empty(); + if (is_array()) + return array_->empty(); + if (is_object()) + return object_->empty(); + return false; + } + + void for_each(const std::function& callback) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (auto& item : *array_) { + callback(item); + } + } else if 
(object_) { + for (auto& item : *object_) { + Value key(item.first); + callback(key); + } + } else if (is_string()) { + for (char c : primitive_.get()) { + auto val = Value(std::string(1, c)); + callback(val); + } + } else { + throw std::runtime_error("Value is not iterable: " + dump()); + } + } + + bool to_bool() const { + if (is_null()) + return false; + if (is_boolean()) + return get(); + if (is_number()) + return get() != 0; + if (is_string()) + return !get().empty(); + if (is_array()) + return !empty(); + return true; + } + + int64_t to_int() const { + if (is_null()) + return 0; + if (is_boolean()) + return get() ? 1 : 0; + if (is_number()) + return static_cast(get()); + if (is_string()) { + try { + return std::stol(get()); + } catch (const std::exception&) { + return 0; + } + } + return 0; + } + + bool operator<(const Value& other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) + return get() < other.get(); + if (is_string() && other.is_string()) + return get() < other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " < " + + other.dump()); + } + bool operator>=(const Value& other) const { return !(*this < other); } + + bool operator>(const Value& other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) + return get() > other.get(); + if (is_string() && other.is_string()) + return get() > other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " > " + + other.dump()); + } + bool operator<=(const Value& other) const { return !(*this > other); } + + bool operator==(const Value& other) const { + if (callable_ || other.callable_) { + if (callable_.get() != other.callable_.get()) + return false; + } + if (array_) { + if (!other.array_) + return false; + if (array_->size() != other.array_->size()) + return false; + for (size_t i = 0; i < array_->size(); ++i) { + if 
(!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || + (*array_)[i] != (*other.array_)[i]) + return false; + } + return true; + } else if (object_) { + if (!other.object_) + return false; + if (object_->size() != other.object_->size()) + return false; + for (const auto& item : *object_) { + if (!item.second.to_bool() || !other.object_->count(item.first) || + item.second != other.object_->at(item.first)) + return false; + } + return true; + } else { + return primitive_ == other.primitive_; + } + } + bool operator!=(const Value& other) const { return !(*this == other); } + + bool contains(const char* key) const { return contains(std::string(key)); } + bool contains(const std::string& key) const { + if (array_) { + return false; + } else if (object_) { + return object_->find(key) != object_->end(); + } else { + throw std::runtime_error( + "contains can only be called on arrays and objects: " + dump()); + } + } + bool contains(const Value& value) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (const auto& item : *array_) { + if (item.to_bool() && item == value) + return true; + } + return false; + } else if (object_) { + if (!value.is_hashable()) + throw std::runtime_error("Unashable type: " + value.dump()); + return object_->find(value.primitive_) != object_->end(); + } else { + throw std::runtime_error( + "contains can only be called on arrays and objects: " + dump()); + } + } + void erase(size_t index) { + if (array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->erase(array_->begin() + index); + } + void erase(const std::string& key) { + if (object_) + throw std::runtime_error("Value is not an object: " + dump()); + object_->erase(key); + } + const Value& at(const Value& index) const { + return const_cast(this)->at(index); + } + Value& at(const Value& index) { + if (!index.is_hashable()) + throw std::runtime_error("Unashable type: " + dump()); + if (is_array()) + return 
array_->at(index.get()); + if (is_object()) + return object_->at(index.primitive_); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + const Value& at(size_t index) const { + return const_cast(this)->at(index); + } + Value& at(size_t index) { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_array()) + return array_->at(index); + if (is_object()) + return object_->at(index); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + template + T get(const std::string& key, T default_value) const { + if (!contains(key)) + return default_value; + return at(key).get(); + } + + template + T get() const { + if (is_primitive()) + return primitive_.get(); + throw std::runtime_error("get not defined for this value type: " + + dump()); + } + + std::string dump(int indent = -1, bool to_json = false) const { + std::ostringstream out; + dump(out, indent, 0, to_json); + return out.str(); + } + + Value operator-() const { + if (is_number_integer()) + return -get(); + else + return -get(); + } + std::string to_str() const { + if (is_string()) + return get(); + if (is_number_integer()) + return std::to_string(get()); + if (is_number_float()) + return std::to_string(get()); + if (is_boolean()) + return get() ? 
"True" : "False"; + if (is_null()) + return "None"; + return dump(); + } + Value operator+(const Value& rhs) const { + if (is_string() || rhs.is_string()) { + return to_str() + rhs.to_str(); + } else if (is_number_integer() && rhs.is_number_integer()) { + return get() + rhs.get(); + } else if (is_array() && rhs.is_array()) { + auto res = Value::array(); + for (const auto& item : *array_) + res.push_back(item); + for (const auto& item : *rhs.array_) + res.push_back(item); + return res; + } else { + return get() + rhs.get(); + } + } + Value operator-(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() - rhs.get(); + else + return get() - rhs.get(); + } + Value operator*(const Value& rhs) const { + if (is_string() && rhs.is_number_integer()) { + std::ostringstream out; + for (int64_t i = 0, n = rhs.get(); i < n; ++i) { + out << to_str(); + } + return out.str(); + } else if (is_number_integer() && rhs.is_number_integer()) + return get() * rhs.get(); + else + return get() * rhs.get(); + } + Value operator/(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() / rhs.get(); + else + return get() / rhs.get(); + } + Value operator%(const Value& rhs) const { + return get() % rhs.get(); + } +}; + +struct ArgumentsValue { + std::vector args; + std::vector> kwargs; + + bool has_named(const std::string& name) { + for (const auto& p : kwargs) { + if (p.first == name) + return true; + } + return false; + } + + Value get_named(const std::string& name) { + for (const auto& [key, value] : kwargs) { + if (key == name) + return value; + } + return Value(); + } + + bool empty() { return args.empty() && kwargs.empty(); } + + void expectArgs(const std::string& method_name, + const std::pair& pos_count, + const std::pair& kw_count) { + if (args.size() < pos_count.first || args.size() > pos_count.second || + kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { + std::ostringstream out; + out 
<< method_name << " must have between " << pos_count.first << " and " + << pos_count.second << " positional arguments and between " + << kw_count.first << " and " << kw_count.second + << " keyword arguments"; + throw std::runtime_error(out.str()); + } + } +}; + +template <> +inline json Value::get() const { + if (is_primitive()) + return primitive_; + if (is_null()) + return json(); + if (array_) { + std::vector res; + for (const auto& item : *array_) { + res.push_back(item.get()); + } + return res; + } + if (object_) { + json res = json::object(); + for (const auto& [key, value] : *object_) { + if (key.is_string()) { + res[key.get()] = value.get(); + } else if (key.is_primitive()) { + res[key.dump()] = value.get(); + } else { + throw std::runtime_error("Invalid key type for conversion to JSON: " + + key.dump()); + } + } + if (is_callable()) { + res["__callable__"] = true; + } + return res; + } + throw std::runtime_error("get not defined for this value type: " + + dump()); +} + +} // namespace minja + +namespace std { +template <> +struct hash { + size_t operator()(const minja::Value& v) const { + if (!v.is_hashable()) + throw std::runtime_error("Unsupported type for hashing: " + v.dump()); + return std::hash()(v.get()); + } +}; +} // namespace std + +namespace minja { + +static std::string error_location_suffix(const std::string& source, + size_t pos) { + auto get_line = [&](size_t line) { + auto start = source.begin(); + for (size_t i = 1; i < line; ++i) { + start = std::find(start, source.end(), '\n') + 1; + } + auto end = std::find(start, source.end(), '\n'); + return std::string(start, end); + }; + auto start = source.begin(); + auto end = source.end(); + auto it = start + pos; + auto line = std::count(start, it, '\n') + 1; + auto max_line = std::count(start, end, '\n') + 1; + auto col = pos - std::string(start, it).rfind('\n'); + std::ostringstream out; + out << " at row " << line << ", column " << col << ":\n"; + if (line > 1) + out << get_line(line - 1) << 
"\n"; + out << get_line(line) << "\n"; + out << std::string(col - 1, ' ') << "^" << "\n"; + if (line < max_line) + out << get_line(line + 1) << "\n"; + + return out.str(); +} + +class Context : public std::enable_shared_from_this { + protected: + Value values_; + std::shared_ptr parent_; + + public: + Context(Value&& values, const std::shared_ptr& parent = nullptr) + : values_(std::move(values)), parent_(parent) { + if (!values_.is_object()) + throw std::runtime_error("Context values must be an object: " + + values_.dump()); + } + virtual ~Context() {} + + static std::shared_ptr builtins(); + static std::shared_ptr make( + Value&& values, const std::shared_ptr& parent = builtins()); + + std::vector keys() { return values_.keys(); } + virtual Value get(const Value& key) { + if (values_.contains(key)) + return values_.at(key); + if (parent_) + return parent_->get(key); + return Value(); + } + virtual Value& at(const Value& key) { + if (values_.contains(key)) + return values_.at(key); + if (parent_) + return parent_->at(key); + throw std::runtime_error("Undefined variable: " + key.dump()); + } + virtual bool contains(const Value& key) { + if (values_.contains(key)) + return true; + if (parent_) + return parent_->contains(key); + return false; + } + virtual void set(const Value& key, Value& value) { values_.set(key, value); } +}; + +struct Location { + std::shared_ptr source; + size_t pos; +}; + +class Expression { + protected: + virtual Value do_evaluate(const std::shared_ptr& context) const = 0; + + public: + using Parameters = + std::vector>>; + + Location location; + + Expression(const Location& location) : location(location) {} + virtual ~Expression() = default; + + Value evaluate(const std::shared_ptr& context) const { + try { + return do_evaluate(context); + } catch (const std::exception& e) { + std::ostringstream out; + out << e.what(); + if (location.source) + out << error_location_suffix(*location.source, location.pos); + throw std::runtime_error(out.str()); 
+ } + } +}; + +class VariableExpr : public Expression { + std::string name; + + public: + VariableExpr(const Location& location, const std::string& n) + : Expression(location), name(n) {} + std::string get_name() const { return name; } + Value do_evaluate(const std::shared_ptr& context) const override { + if (!context->contains(name)) { + return Value(); + } + return context->at(name); + } +}; + +static void destructuring_assign(const std::vector& var_names, + const std::shared_ptr& context, + Value& item) { + if (var_names.size() == 1) { + Value name(var_names[0]); + context->set(name, item); + } else { + if (!item.is_array() || item.size() != var_names.size()) { + throw std::runtime_error( + "Mismatched number of variables and items in destructuring " + "assignment"); + } + for (size_t i = 0; i < var_names.size(); ++i) { + context->set(var_names[i], item.at(i)); + } + } +} + +enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; + +class TemplateToken { + public: + enum class Type { + Text, + Expression, + If, + Else, + Elif, + EndIf, + For, + EndFor, + Set, + EndSet, + Comment, + Macro, + EndMacro, + Filter, + EndFilter + }; + + static std::string typeToString(Type t) { + switch (t) { + case Type::Text: + return "text"; + case Type::Expression: + return "expression"; + case Type::If: + return "if"; + case Type::Else: + return "else"; + case Type::Elif: + return "elif"; + case Type::EndIf: + return "endif"; + case Type::For: + return "for"; + case Type::EndFor: + return "endfor"; + case Type::Set: + return "set"; + case Type::EndSet: + return "endset"; + case Type::Comment: + return "comment"; + case Type::Macro: + return "macro"; + case Type::EndMacro: + return "endmacro"; + case Type::Filter: + return "filter"; + case Type::EndFilter: + return "endfilter"; + } + return "Unknown"; + } + + TemplateToken(Type type, const Location& location, SpaceHandling pre, + SpaceHandling post) + : type(type), location(location), pre_space(pre), post_space(post) {} + 
virtual ~TemplateToken() = default; + + Type type; + Location location; + SpaceHandling pre_space = SpaceHandling::Keep; + SpaceHandling post_space = SpaceHandling::Keep; +}; + +struct TextTemplateToken : public TemplateToken { + std::string text; + TextTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post, const std::string& t) + : TemplateToken(Type::Text, location, pre, post), text(t) {} +}; + +struct ExpressionTemplateToken : public TemplateToken { + std::shared_ptr expr; + ExpressionTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post, std::shared_ptr&& e) + : TemplateToken(Type::Expression, location, pre, post), + expr(std::move(e)) {} +}; + +struct IfTemplateToken : public TemplateToken { + std::shared_ptr condition; + IfTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post, std::shared_ptr&& c) + : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} +}; + +struct ElifTemplateToken : public TemplateToken { + std::shared_ptr condition; + ElifTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post, std::shared_ptr&& c) + : TemplateToken(Type::Elif, location, pre, post), + condition(std::move(c)) {} +}; + +struct ElseTemplateToken : public TemplateToken { + ElseTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post) + : TemplateToken(Type::Else, location, pre, post) {} +}; + +struct EndIfTemplateToken : public TemplateToken { + EndIfTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post) + : TemplateToken(Type::EndIf, location, pre, post) {} +}; + +struct MacroTemplateToken : public TemplateToken { + std::shared_ptr name; + Expression::Parameters params; + MacroTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post, std::shared_ptr&& n, + Expression::Parameters&& p) + : TemplateToken(Type::Macro, location, pre, post), + name(std::move(n)), + params(std::move(p)) {} 
+}; + +struct EndMacroTemplateToken : public TemplateToken { + EndMacroTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post) + : TemplateToken(Type::EndMacro, location, pre, post) {} +}; + +struct FilterTemplateToken : public TemplateToken { + std::shared_ptr filter; + FilterTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post, std::shared_ptr&& filter) + : TemplateToken(Type::Filter, location, pre, post), + filter(std::move(filter)) {} +}; + +struct EndFilterTemplateToken : public TemplateToken { + EndFilterTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post) + : TemplateToken(Type::EndFilter, location, pre, post) {} +}; + +struct ForTemplateToken : public TemplateToken { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + bool recursive; + ForTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post, const std::vector& vns, + std::shared_ptr&& iter, + std::shared_ptr&& c, bool r) + : TemplateToken(Type::For, location, pre, post), + var_names(vns), + iterable(std::move(iter)), + condition(std::move(c)), + recursive(r) {} +}; + +struct EndForTemplateToken : public TemplateToken { + EndForTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post) + : TemplateToken(Type::EndFor, location, pre, post) {} +}; + +struct SetTemplateToken : public TemplateToken { + std::string ns; + std::vector var_names; + std::shared_ptr value; + SetTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post, const std::string& ns, + const std::vector& vns, + std::shared_ptr&& v) + : TemplateToken(Type::Set, location, pre, post), + ns(ns), + var_names(vns), + value(std::move(v)) {} +}; + +struct EndSetTemplateToken : public TemplateToken { + EndSetTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post) + : TemplateToken(Type::EndSet, location, pre, post) {} +}; + +struct 
CommentTemplateToken : public TemplateToken { + std::string text; + CommentTemplateToken(const Location& location, SpaceHandling pre, + SpaceHandling post, const std::string& t) + : TemplateToken(Type::Comment, location, pre, post), text(t) {} +}; + +class TemplateNode { + Location location_; + + protected: + virtual void do_render(std::ostringstream& out, + const std::shared_ptr& context) const = 0; + + public: + TemplateNode(const Location& location) : location_(location) {} + void render(std::ostringstream& out, + const std::shared_ptr& context) const { + try { + do_render(out, context); + } catch (const std::exception& e) { + std::ostringstream err; + err << e.what(); + if (location_.source) + err << error_location_suffix(*location_.source, location_.pos); + throw std::runtime_error(err.str()); + } + } + const Location& location() const { return location_; } + virtual ~TemplateNode() = default; + std::string render(const std::shared_ptr& context) const { + std::ostringstream out; + render(out, context); + return out.str(); + } +}; + +class SequenceNode : public TemplateNode { + std::vector> children; + + public: + SequenceNode(const Location& location, + std::vector>&& c) + : TemplateNode(location), children(std::move(c)) {} + void do_render(std::ostringstream& out, + const std::shared_ptr& context) const override { + for (const auto& child : children) + child->render(out, context); + } +}; + +class TextNode : public TemplateNode { + std::string text; + + public: + TextNode(const Location& location, const std::string& t) + : TemplateNode(location), text(t) {} + void do_render(std::ostringstream& out, + const std::shared_ptr&) const override { + out << text; + } +}; + +class ExpressionNode : public TemplateNode { + std::shared_ptr expr; + + public: + ExpressionNode(const Location& location, std::shared_ptr&& e) + : TemplateNode(location), expr(std::move(e)) {} + void do_render(std::ostringstream& out, + const std::shared_ptr& context) const override { + if 
(!expr) + throw std::runtime_error("ExpressionNode.expr is null"); + auto result = expr->evaluate(context); + if (result.is_string()) { + out << result.get(); + } else if (result.is_boolean()) { + out << (result.get() ? "True" : "False"); + } else if (!result.is_null()) { + out << result.dump(); + } + } +}; + +class IfNode : public TemplateNode { + std::vector< + std::pair, std::shared_ptr>> + cascade; + + public: + IfNode(const Location& location, + std::vector, + std::shared_ptr>>&& c) + : TemplateNode(location), cascade(std::move(c)) {} + void do_render(std::ostringstream& out, + const std::shared_ptr& context) const override { + for (const auto& branch : cascade) { + auto enter_branch = true; + if (branch.first) { + enter_branch = branch.first->evaluate(context).to_bool(); + } + if (enter_branch) { + if (!branch.second) + throw std::runtime_error("IfNode.cascade.second is null"); + branch.second->render(out, context); + return; + } + } + } +}; + +class ForNode : public TemplateNode { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + std::shared_ptr body; + bool recursive; + std::shared_ptr else_body; + + public: + ForNode(const Location& location, std::vector&& var_names, + std::shared_ptr&& iterable, + std::shared_ptr&& condition, + std::shared_ptr&& body, bool recursive, + std::shared_ptr&& else_body) + : TemplateNode(location), + var_names(var_names), + iterable(std::move(iterable)), + condition(std::move(condition)), + body(std::move(body)), + recursive(recursive), + else_body(std::move(else_body)) {} + + void do_render(std::ostringstream& out, + const std::shared_ptr& context) const override { + // https://jinja.palletsprojects.com/en/3.0.x/templates/#for + if (!iterable) + throw std::runtime_error("ForNode.iterable is null"); + if (!body) + throw std::runtime_error("ForNode.body is null"); + + auto iterable_value = iterable->evaluate(context); + Value::CallableType loop_function; + + std::function visit = [&](Value& iter) 
{ + auto filtered_items = Value::array(); + if (!iter.is_null()) { + if (!iterable_value.is_iterable()) { + throw std::runtime_error("For loop iterable must be iterable: " + + iterable_value.dump()); + } + iterable_value.for_each([&](Value& item) { + destructuring_assign(var_names, context, item); + if (!condition || condition->evaluate(context).to_bool()) { + filtered_items.push_back(item); + } + }); + } + if (filtered_items.empty()) { + if (else_body) { + else_body->render(out, context); + } + } else { + auto loop = + recursive ? Value::callable(loop_function) : Value::object(); + loop.set("length", (int64_t)filtered_items.size()); + + size_t cycle_index = 0; + loop.set("cycle", Value::callable([&](const std::shared_ptr&, + ArgumentsValue& args) { + if (args.args.empty() || !args.kwargs.empty()) { + throw std::runtime_error( + "cycle() expects at least 1 positional argument and " + "no named arg"); + } + auto item = args.args[cycle_index]; + cycle_index = (cycle_index + 1) % args.args.size(); + return item; + })); + auto loop_context = Context::make(Value::object(), context); + loop_context->set("loop", loop); + for (size_t i = 0, n = filtered_items.size(); i < n; ++i) { + auto& item = filtered_items.at(i); + destructuring_assign(var_names, loop_context, item); + loop.set("index", (int64_t)i + 1); + loop.set("index0", (int64_t)i); + loop.set("revindex", (int64_t)(n - i)); + loop.set("revindex0", (int64_t)(n - i - 1)); + loop.set("length", (int64_t)n); + loop.set("first", i == 0); + loop.set("last", i == (n - 1)); + loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); + loop.set("nextitem", i < n - 1 ? 
filtered_items.at(i + 1) : Value()); + body->render(out, loop_context); + } + } + }; + + if (recursive) { + loop_function = [&](const std::shared_ptr&, + ArgumentsValue& args) { + if (args.args.size() != 1 || !args.kwargs.empty() || + !args.args[0].is_array()) { + throw std::runtime_error( + "loop() expects exactly 1 positional iterable argument"); + } + auto& items = args.args[0]; + visit(items); + return Value(); + }; + } + + visit(iterable_value); + } +}; + +class MacroNode : public TemplateNode { + std::shared_ptr name; + Expression::Parameters params; + std::shared_ptr body; + std::unordered_map named_param_positions; + + public: + MacroNode(const Location& location, std::shared_ptr&& n, + Expression::Parameters&& p, std::shared_ptr&& b) + : TemplateNode(location), + name(std::move(n)), + params(std::move(p)), + body(std::move(b)) { + for (size_t i = 0; i < params.size(); ++i) { + const auto& name = params[i].first; + if (!name.empty()) { + named_param_positions[name] = i; + } + } + } + void do_render(std::ostringstream&, + const std::shared_ptr& macro_context) const override { + if (!name) + throw std::runtime_error("MacroNode.name is null"); + if (!body) + throw std::runtime_error("MacroNode.body is null"); + auto callable = Value::callable([&](const std::shared_ptr& context, + ArgumentsValue& args) { + auto call_context = macro_context; + std::vector param_set(params.size(), false); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto& arg = args.args[i]; + if (i >= params.size()) + throw std::runtime_error("Too many positional arguments for macro " + + name->get_name()); + param_set[i] = true; + auto& param_name = params[i].first; + call_context->set(param_name, arg); + } + for (auto& [arg_name, value] : args.kwargs) { + auto it = named_param_positions.find(arg_name); + if (it == named_param_positions.end()) + throw std::runtime_error("Unknown parameter name for macro " + + name->get_name() + ": " + arg_name); + + call_context->set(arg_name, 
value); + param_set[it->second] = true; + } + // Set default values for parameters that were not passed + for (size_t i = 0, n = params.size(); i < n; i++) { + if (!param_set[i] && params[i].second != nullptr) { + auto val = params[i].second->evaluate(context); + call_context->set(params[i].first, val); + } + } + return body->render(call_context); + }); + macro_context->set(name->get_name(), callable); + } +}; + +class FilterNode : public TemplateNode { + std::shared_ptr filter; + std::shared_ptr body; + + public: + FilterNode(const Location& location, std::shared_ptr&& f, + std::shared_ptr&& b) + : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} + + void do_render(std::ostringstream& out, + const std::shared_ptr& context) const override { + if (!filter) + throw std::runtime_error("FilterNode.filter is null"); + if (!body) + throw std::runtime_error("FilterNode.body is null"); + auto filter_value = filter->evaluate(context); + if (!filter_value.is_callable()) { + throw std::runtime_error("Filter must be a callable: " + + filter_value.dump()); + } + std::string rendered_body = body->render(context); + + ArgumentsValue filter_args = {{Value(rendered_body)}, {}}; + auto result = filter_value.call(context, filter_args); + out << result.to_str(); + } +}; + +class SetNode : public TemplateNode { + std::string ns; + std::vector var_names; + std::shared_ptr value; + + public: + SetNode(const Location& location, const std::string& ns, + const std::vector& vns, std::shared_ptr&& v) + : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} + void do_render(std::ostringstream&, + const std::shared_ptr& context) const override { + if (!value) + throw std::runtime_error("SetNode.value is null"); + if (!ns.empty()) { + if (var_names.size() != 1) { + throw std::runtime_error( + "Namespaced set only supports a single variable name"); + } + auto& name = var_names[0]; + auto ns_value = context->get(ns); + if (!ns_value.is_object()) + throw 
std::runtime_error("Namespace '" + ns + "' is not an object"); + ns_value.set(name, this->value->evaluate(context)); + } else { + auto val = value->evaluate(context); + destructuring_assign(var_names, context, val); + } + } +}; + +class SetTemplateNode : public TemplateNode { + std::string name; + std::shared_ptr template_value; + + public: + SetTemplateNode(const Location& location, const std::string& name, + std::shared_ptr&& tv) + : TemplateNode(location), name(name), template_value(std::move(tv)) {} + void do_render(std::ostringstream&, + const std::shared_ptr& context) const override { + if (!template_value) + throw std::runtime_error("SetTemplateNode.template_value is null"); + Value value{template_value->render(context)}; + context->set(name, value); + } +}; + +class IfExpr : public Expression { + std::shared_ptr condition; + std::shared_ptr then_expr; + std::shared_ptr else_expr; + + public: + IfExpr(const Location& location, std::shared_ptr&& c, + std::shared_ptr&& t, std::shared_ptr&& e) + : Expression(location), + condition(std::move(c)), + then_expr(std::move(t)), + else_expr(std::move(e)) {} + Value do_evaluate(const std::shared_ptr& context) const override { + if (!condition) + throw std::runtime_error("IfExpr.condition is null"); + if (!then_expr) + throw std::runtime_error("IfExpr.then_expr is null"); + if (condition->evaluate(context).to_bool()) { + return then_expr->evaluate(context); + } + if (else_expr) { + return else_expr->evaluate(context); + } + return nullptr; + } +}; + +class LiteralExpr : public Expression { + Value value; + + public: + LiteralExpr(const Location& location, const Value& v) + : Expression(location), value(v) {} + Value do_evaluate(const std::shared_ptr&) const override { + return value; + } +}; + +class ArrayExpr : public Expression { + std::vector> elements; + + public: + ArrayExpr(const Location& location, + std::vector>&& e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr& 
context) const override { + auto result = Value::array(); + for (const auto& e : elements) { + if (!e) + throw std::runtime_error("Array element is null"); + result.push_back(e->evaluate(context)); + } + return result; + } +}; + +class DictExpr : public Expression { + std::vector< + std::pair, std::shared_ptr>> + elements; + + public: + DictExpr(const Location& location, + std::vector, + std::shared_ptr>>&& e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr& context) const override { + auto result = Value::object(); + for (const auto& [key, value] : elements) { + if (!key) + throw std::runtime_error("Dict key is null"); + if (!value) + throw std::runtime_error("Dict value is null"); + result.set(key->evaluate(context), value->evaluate(context)); + } + return result; + } +}; + +class SliceExpr : public Expression { + public: + std::shared_ptr start, end; + SliceExpr(const Location& location, std::shared_ptr&& s, + std::shared_ptr&& e) + : Expression(location), start(std::move(s)), end(std::move(e)) {} + Value do_evaluate(const std::shared_ptr&) const override { + throw std::runtime_error("SliceExpr not implemented"); + } +}; + +class SubscriptExpr : public Expression { + std::shared_ptr base; + std::shared_ptr index; + + public: + SubscriptExpr(const Location& location, std::shared_ptr&& b, + std::shared_ptr&& i) + : Expression(location), base(std::move(b)), index(std::move(i)) {} + Value do_evaluate(const std::shared_ptr& context) const override { + if (!base) + throw std::runtime_error("SubscriptExpr.base is null"); + if (!index) + throw std::runtime_error("SubscriptExpr.index is null"); + auto target_value = base->evaluate(context); + if (auto slice = dynamic_cast(index.get())) { + auto start = + slice->start ? slice->start->evaluate(context).get() : 0; + auto end = slice->end ? 
slice->end->evaluate(context).get() + : (int64_t)target_value.size(); + if (target_value.is_string()) { + std::string s = target_value.get(); + if (start < 0) + start = s.size() + start; + if (end < 0) + end = s.size() + end; + return s.substr(start, end - start); + } else if (target_value.is_array()) { + if (start < 0) + start = target_value.size() + start; + if (end < 0) + end = target_value.size() + end; + auto result = Value::array(); + for (auto i = start; i < end; ++i) { + result.push_back(target_value.at(i)); + } + return result; + } else { + throw std::runtime_error( + target_value.is_null() + ? "Cannot subscript null" + : "Subscripting only supported on arrays and strings"); + } + } else { + auto index_value = index->evaluate(context); + if (target_value.is_null()) { + if (auto t = dynamic_cast(base.get())) { + throw std::runtime_error( + "'" + t->get_name() + "' is " + + (context->contains(t->get_name()) ? "null" : "not defined")); + } + throw std::runtime_error("Trying to access property '" + + index_value.dump() + "' on null!"); + } + return target_value.get(index_value); + } + } +}; + +class UnaryOpExpr : public Expression { + public: + enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; + std::shared_ptr expr; + Op op; + UnaryOpExpr(const Location& location, std::shared_ptr&& e, Op o) + : Expression(location), expr(std::move(e)), op(o) {} + Value do_evaluate(const std::shared_ptr& context) const override { + if (!expr) + throw std::runtime_error("UnaryOpExpr.expr is null"); + auto e = expr->evaluate(context); + switch (op) { + case Op::Plus: + return e; + case Op::Minus: + return -e; + case Op::LogicalNot: + return !e.to_bool(); + case Op::Expansion: + case Op::ExpansionDict: + throw std::runtime_error( + "Expansion operator is only supported in function calls and " + "collections"); + } + throw std::runtime_error("Unknown unary operator"); + } +}; + +class BinaryOpExpr : public Expression { + public: + enum class Op { + StrConcat, + 
Add, + Sub, + Mul, + MulMul, + Div, + DivDiv, + Mod, + Eq, + Ne, + Lt, + Gt, + Le, + Ge, + And, + Or, + In, + NotIn, + Is, + IsNot + }; + + private: + std::shared_ptr left; + std::shared_ptr right; + Op op; + + public: + BinaryOpExpr(const Location& location, std::shared_ptr&& l, + std::shared_ptr&& r, Op o) + : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} + Value do_evaluate(const std::shared_ptr& context) const override { + if (!left) + throw std::runtime_error("BinaryOpExpr.left is null"); + if (!right) + throw std::runtime_error("BinaryOpExpr.right is null"); + auto l = left->evaluate(context); + + auto do_eval = [&](const Value& l) -> Value { + if (op == Op::Is || op == Op::IsNot) { + auto t = dynamic_cast(right.get()); + if (!t) + throw std::runtime_error( + "Right side of 'is' operator must be a variable"); + + auto eval = [&]() { + const auto& name = t->get_name(); + if (name == "none") + return l.is_null(); + if (name == "boolean") + return l.is_boolean(); + if (name == "integer") + return l.is_number_integer(); + if (name == "float") + return l.is_number_float(); + if (name == "number") + return l.is_number(); + if (name == "string") + return l.is_string(); + if (name == "mapping") + return l.is_object(); + if (name == "iterable") + return l.is_iterable(); + if (name == "sequence") + return l.is_array(); + if (name == "defined") + return !l.is_null(); + throw std::runtime_error("Unknown type for 'is' operator: " + name); + }; + auto value = eval(); + return Value(op == Op::Is ? 
value : !value); + } + + if (op == Op::And) { + if (!l.to_bool()) + return Value(false); + return right->evaluate(context).to_bool(); + } else if (op == Op::Or) { + if (l.to_bool()) + return Value(true); + return right->evaluate(context).to_bool(); + } + + auto r = right->evaluate(context); + switch (op) { + case Op::StrConcat: + return l.to_str() + r.to_str(); + case Op::Add: + return l + r; + case Op::Sub: + return l - r; + case Op::Mul: + return l * r; + case Op::Div: + return l / r; + case Op::MulMul: + return std::pow(l.get(), r.get()); + case Op::DivDiv: + return l.get() / r.get(); + case Op::Mod: + return l.get() % r.get(); + case Op::Eq: + return l == r; + case Op::Ne: + return l != r; + case Op::Lt: + return l < r; + case Op::Gt: + return l > r; + case Op::Le: + return l <= r; + case Op::Ge: + return l >= r; + case Op::In: + return (r.is_array() || r.is_object()) && r.contains(l); + case Op::NotIn: + return !(r.is_array() && r.contains(l)); + default: + break; + } + throw std::runtime_error("Unknown binary operator"); + }; + + if (l.is_callable()) { + return Value::callable( + [l, do_eval](const std::shared_ptr& context, + ArgumentsValue& args) { + auto ll = l.call(context, args); + return do_eval(ll); //args[0].second); + }); + } else { + return do_eval(l); + } + } +}; + +struct ArgumentsExpression { + std::vector> args; + std::vector>> kwargs; + + ArgumentsValue evaluate(const std::shared_ptr& context) const { + ArgumentsValue vargs; + for (const auto& arg : this->args) { + if (auto un_expr = std::dynamic_pointer_cast(arg)) { + if (un_expr->op == UnaryOpExpr::Op::Expansion) { + auto array = un_expr->expr->evaluate(context); + if (!array.is_array()) { + throw std::runtime_error( + "Expansion operator only supported on arrays"); + } + array.for_each([&](Value& value) { vargs.args.push_back(value); }); + continue; + } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) { + auto dict = un_expr->expr->evaluate(context); + if (!dict.is_object()) { + throw 
std::runtime_error( + "ExpansionDict operator only supported on objects"); + } + dict.for_each([&](const Value& key) { + vargs.kwargs.push_back({key.get(), dict.at(key)}); + }); + continue; + } + } + vargs.args.push_back(arg->evaluate(context)); + } + for (const auto& [name, value] : this->kwargs) { + vargs.kwargs.push_back({name, value->evaluate(context)}); + } + return vargs; + } +}; + +static std::string strip(const std::string& s) { + static std::regex trailing_spaces_regex("^\\s+|\\s+$"); + return std::regex_replace(s, trailing_spaces_regex, ""); +} + +static std::string html_escape(const std::string& s) { + std::string result; + result.reserve(s.size()); + for (const auto& c : s) { + switch (c) { + case '&': + result += "&"; + break; + case '<': + result += "<"; + break; + case '>': + result += ">"; + break; + case '"': + result += """; + break; + case '\'': + result += "'"; + break; + default: + result += c; + break; + } + } + return result; +} + +class MethodCallExpr : public Expression { + std::shared_ptr object; + std::shared_ptr method; + ArgumentsExpression args; + + public: + MethodCallExpr(const Location& location, std::shared_ptr&& obj, + std::shared_ptr&& m, ArgumentsExpression&& a) + : Expression(location), + object(std::move(obj)), + method(std::move(m)), + args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr& context) const override { + if (!object) + throw std::runtime_error("MethodCallExpr.object is null"); + if (!method) + throw std::runtime_error("MethodCallExpr.method is null"); + auto obj = object->evaluate(context); + auto vargs = args.evaluate(context); + if (obj.is_null()) { + throw std::runtime_error("Trying to call method '" + method->get_name() + + "' on null"); + } + if (obj.is_array()) { + if (method->get_name() == "append") { + vargs.expectArgs("append method", {1, 1}, {0, 0}); + obj.push_back(vargs.args[0]); + return Value(); + } else if (method->get_name() == "insert") { + vargs.expectArgs("insert method", {2, 2}, {0, 
0}); + auto index = vargs.args[0].get(); + if (index < 0 || index > (int64_t)obj.size()) + throw std::runtime_error("Index out of range for insert method"); + obj.insert(index, vargs.args[1]); + return Value(); + } + } else if (obj.is_object()) { + if (method->get_name() == "items") { + vargs.expectArgs("items method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value::array({key, obj.at(key)})); + } + return result; + } else if (method->get_name() == "get") { + vargs.expectArgs("get method", {1, 2}, {0, 0}); + auto key = vargs.args[0]; + if (vargs.args.size() == 1) { + return obj.contains(key) ? obj.at(key) : Value(); + } else { + return obj.contains(key) ? obj.at(key) : vargs.args[1]; + } + } else if (obj.contains(method->get_name())) { + auto callable = obj.at(method->get_name()); + if (!callable.is_callable()) { + throw std::runtime_error("Property '" + method->get_name() + + "' is not callable"); + } + return callable.call(context, vargs); + } + } else if (obj.is_string()) { + auto str = obj.get(); + if (method->get_name() == "strip") { + vargs.expectArgs("strip method", {0, 0}, {0, 0}); + return Value(strip(str)); + } else if (method->get_name() == "endswith") { + vargs.expectArgs("endswith method", {1, 1}, {0, 0}); + auto suffix = vargs.args[0].get(); + return suffix.length() <= str.length() && + std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "title") { + vargs.expectArgs("title method", {0, 0}, {0, 0}); + auto res = str; + for (size_t i = 0, n = res.size(); i < n; ++i) { + if (i == 0 || std::isspace(res[i - 1])) + res[i] = std::toupper(res[i]); + else + res[i] = std::tolower(res[i]); + } + return res; + } + } + throw std::runtime_error("Unknown method: " + method->get_name()); + } +}; + +class CallExpr : public Expression { + public: + std::shared_ptr object; + ArgumentsExpression args; + CallExpr(const Location& location, std::shared_ptr&& obj, + 
ArgumentsExpression&& a) + : Expression(location), object(std::move(obj)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr& context) const override { + if (!object) + throw std::runtime_error("CallExpr.object is null"); + auto obj = object->evaluate(context); + if (!obj.is_callable()) { + throw std::runtime_error("Object is not callable: " + obj.dump(2)); + } + auto vargs = args.evaluate(context); + return obj.call(context, vargs); + } +}; + +class FilterExpr : public Expression { + std::vector> parts; + + public: + FilterExpr(const Location& location, + std::vector>&& p) + : Expression(location), parts(std::move(p)) {} + Value do_evaluate(const std::shared_ptr& context) const override { + Value result; + bool first = true; + for (const auto& part : parts) { + if (!part) + throw std::runtime_error("FilterExpr.part is null"); + if (first) { + first = false; + result = part->evaluate(context); + } else { + if (auto ce = dynamic_cast(part.get())) { + auto target = ce->object->evaluate(context); + ArgumentsValue args = ce->args.evaluate(context); + args.args.insert(args.args.begin(), result); + result = target.call(context, args); + } else { + auto callable = part->evaluate(context); + ArgumentsValue args; + args.args.insert(args.args.begin(), result); + result = callable.call(context, args); + } + } + } + return result; + } + + void prepend(std::shared_ptr&& e) { + parts.insert(parts.begin(), std::move(e)); + } +}; + +class Parser { + private: + using CharIterator = std::string::const_iterator; + + std::shared_ptr template_str; + CharIterator start, end, it; + Options options; + + Parser(const std::shared_ptr& template_str, + const Options& options) + : template_str(template_str), options(options) { + if (!template_str) + throw std::runtime_error("Template string is null"); + start = it = this->template_str->begin(); + end = this->template_str->end(); + } + + bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) { + if (space_handling 
== SpaceHandling::Strip) { + while (it != end && std::isspace(*it)) + ++it; + } + return true; + } + + std::unique_ptr parseString() { + auto doParse = [&](char quote) -> std::unique_ptr { + if (it == end || *it != quote) + return nullptr; + std::string result; + bool escape = false; + for (++it; it != end; ++it) { + if (escape) { + escape = false; + switch (*it) { + case 'n': + result += '\n'; + break; + case 'r': + result += '\r'; + break; + case 't': + result += '\t'; + break; + case 'b': + result += '\b'; + break; + case 'f': + result += '\f'; + break; + case '\\': + result += '\\'; + break; + default: + if (*it == quote) { + result += quote; + } else { + result += *it; + } + break; + } + } else if (*it == '\\') { + escape = true; + } else if (*it == quote) { + ++it; + return std::make_unique(std::move(result)); + } else { + result += *it; + } + } + return nullptr; + }; + + consumeSpaces(); + if (it == end) + return nullptr; + if (*it == '"') + return doParse('"'); + if (*it == '\'') + return doParse('\''); + return nullptr; + } + + json parseNumber(CharIterator& it, const CharIterator& end) { + auto before = it; + consumeSpaces(); + auto start = it; + bool hasDecimal = false; + bool hasExponent = false; + + if (it != end && (*it == '-' || *it == '+')) + ++it; + + while (it != end) { + if (std::isdigit(*it)) { + ++it; + } else if (*it == '.') { + if (hasDecimal) + throw std::runtime_error("Multiple decimal points"); + hasDecimal = true; + ++it; + } else if (it != start && (*it == 'e' || *it == 'E')) { + if (hasExponent) + throw std::runtime_error("Multiple exponents"); + hasExponent = true; + ++it; + } else { + break; + } + } + if (start == it) { + it = before; + return json(); // No valid characters found + } + + std::string str(start, it); + try { + return json::parse(str); + } catch (json::parse_error& e) { + throw std::runtime_error("Failed to parse number: '" + str + "' (" + + std::string(e.what()) + ")"); + return json(); + } + } + + /** integer, float, 
bool, string */ + std::shared_ptr parseConstant() { + auto start = it; + consumeSpaces(); + if (it == end) + return nullptr; + if (*it == '"' || *it == '\'') { + auto str = parseString(); + if (str) + return std::make_shared(*str); + } + static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)"); + auto token = consumeToken(prim_tok); + if (!token.empty()) { + if (token == "true" || token == "True") + return std::make_shared(true); + if (token == "false" || token == "False") + return std::make_shared(false); + if (token == "None") + return std::make_shared(nullptr); + throw std::runtime_error("Unknown constant token: " + token); + } + + auto number = parseNumber(it, end); + if (!number.is_null()) + return std::make_shared(number); + + it = start; + return nullptr; + } + + class expression_parsing_error : public std::runtime_error { + const CharIterator it; + + public: + expression_parsing_error(const std::string& message, const CharIterator& it) + : std::runtime_error(message), it(it) {} + size_t get_pos(const CharIterator& begin) const { + return std::distance(begin, it); + } + }; + + bool peekSymbols(const std::vector& symbols) const { + for (const auto& symbol : symbols) { + if (std::distance(it, end) >= (int64_t)symbol.size() && + std::string(it, it + symbol.size()) == symbol) { + return true; + } + } + return false; + } + + std::vector consumeTokenGroups( + const std::regex& regex, + SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + std::vector ret; + for (size_t i = 0, n = match.size(); i < n; ++i) { + ret.push_back(match[i].str()); + } + return ret; + } + it = start; + return {}; + } + std::string consumeToken( + const std::regex& regex, + SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + 
if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + return match[0].str(); + } + it = start; + return ""; + } + + std::string consumeToken( + const std::string& token, + SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + if (std::distance(it, end) >= (int64_t)token.size() && + std::string(it, it + token.size()) == token) { + it += token.size(); + return token; + } + it = start; + return ""; + } + + std::shared_ptr parseExpression(bool allow_if_expr = true) { + auto left = parseLogicalOr(); + if (it == end) + return left; + + if (!allow_if_expr) + return left; + + static std::regex if_tok(R"(if\b)"); + if (consumeToken(if_tok).empty()) { + return left; + } + + auto location = get_location(); + auto [condition, else_expr] = parseIfExpression(); + return std::make_shared(location, std::move(condition), + std::move(left), std::move(else_expr)); + } + + Location get_location() const { + return {template_str, (size_t)std::distance(start, it)}; + } + + std::pair, std::shared_ptr> + parseIfExpression() { + auto condition = parseLogicalOr(); + if (!condition) + throw std::runtime_error("Expected condition expression"); + + static std::regex else_tok(R"(else\b)"); + std::shared_ptr else_expr; + if (!consumeToken(else_tok).empty()) { + else_expr = parseExpression(); + if (!else_expr) + throw std::runtime_error("Expected 'else' expression"); + } + return std::pair(std::move(condition), std::move(else_expr)); + } + + std::shared_ptr parseLogicalOr() { + auto left = parseLogicalAnd(); + if (!left) + throw std::runtime_error("Expected left side of 'logical or' expression"); + + static std::regex or_tok(R"(or\b)"); + auto location = get_location(); + while (!consumeToken(or_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) + throw std::runtime_error("Expected right side of 'or' expression"); + left = std::make_shared( + location, std::move(left), 
std::move(right), BinaryOpExpr::Op::Or); + } + return left; + } + + std::shared_ptr parseLogicalNot() { + static std::regex not_tok(R"(not\b)"); + auto location = get_location(); + + if (!consumeToken(not_tok).empty()) { + auto sub = parseLogicalNot(); + if (!sub) + throw std::runtime_error("Expected expression after 'not' keyword"); + return std::make_shared(location, std::move(sub), + UnaryOpExpr::Op::LogicalNot); + } + return parseLogicalCompare(); + } + + std::shared_ptr parseLogicalAnd() { + auto left = parseLogicalNot(); + if (!left) + throw std::runtime_error( + "Expected left side of 'logical and' expression"); + + static std::regex and_tok(R"(and\b)"); + auto location = get_location(); + while (!consumeToken(and_tok).empty()) { + auto right = parseLogicalNot(); + if (!right) + throw std::runtime_error("Expected right side of 'and' expression"); + left = std::make_shared( + location, std::move(left), std::move(right), BinaryOpExpr::Op::And); + } + return left; + } + + std::shared_ptr parseLogicalCompare() { + auto left = parseStringConcat(); + if (!left) + throw std::runtime_error( + "Expected left side of 'logical compare' expression"); + + static std::regex compare_tok( + R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)"); + static std::regex not_tok(R"(not\b)"); + std::string op_str; + while (!(op_str = consumeToken(compare_tok)).empty()) { + auto location = get_location(); + if (op_str == "is") { + auto negated = !consumeToken(not_tok).empty(); + + auto identifier = parseIdentifier(); + if (!identifier) + throw std::runtime_error("Expected identifier after 'is' keyword"); + + return std::make_shared( + left->location, std::move(left), std::move(identifier), + negated ? 
BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is); + } + auto right = parseStringConcat(); + if (!right) + throw std::runtime_error( + "Expected right side of 'logical compare' expression"); + BinaryOpExpr::Op op; + if (op_str == "==") + op = BinaryOpExpr::Op::Eq; + else if (op_str == "!=") + op = BinaryOpExpr::Op::Ne; + else if (op_str == "<") + op = BinaryOpExpr::Op::Lt; + else if (op_str == ">") + op = BinaryOpExpr::Op::Gt; + else if (op_str == "<=") + op = BinaryOpExpr::Op::Le; + else if (op_str == ">=") + op = BinaryOpExpr::Op::Ge; + else if (op_str == "in") + op = BinaryOpExpr::Op::In; + else if (op_str.substr(0, 3) == "not") + op = BinaryOpExpr::Op::NotIn; + else + throw std::runtime_error("Unknown comparison operator: " + op_str); + left = std::make_shared(get_location(), std::move(left), + std::move(right), op); + } + return left; + } + + Expression::Parameters parseParameters() { + consumeSpaces(); + if (consumeToken("(").empty()) + throw std::runtime_error("Expected opening parenthesis in param list"); + + Expression::Parameters result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) + throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) + throw std::runtime_error("Expected expression in for named arg"); + result.emplace_back(ident->get_name(), std::move(value)); + } else { + result.emplace_back(ident->get_name(), nullptr); + } + } else { + result.emplace_back(std::string(), std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + ArgumentsExpression parseCallArgs() { + consumeSpaces(); + if (consumeToken("(").empty()) + 
throw std::runtime_error("Expected opening parenthesis in call args"); + + ArgumentsExpression result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) + throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) + throw std::runtime_error("Expected expression in for named arg"); + result.kwargs.emplace_back(ident->get_name(), std::move(value)); + } else { + result.args.emplace_back(std::move(expr)); + } + } else { + result.args.emplace_back(std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + std::shared_ptr parseIdentifier() { + static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)"); + auto location = get_location(); + auto ident = consumeToken(ident_regex); + if (ident.empty()) + return nullptr; + return std::make_shared(location, ident); + } + + std::shared_ptr parseStringConcat() { + auto left = parseMathPow(); + if (!left) + throw std::runtime_error( + "Expected left side of 'string concat' expression"); + + static std::regex concat_tok(R"(~(?!\}))"); + if (!consumeToken(concat_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) + throw std::runtime_error( + "Expected right side of 'string concat' expression"); + left = std::make_shared(get_location(), std::move(left), + std::move(right), + BinaryOpExpr::Op::StrConcat); + } + return left; + } + + std::shared_ptr parseMathPow() { + auto left = parseMathPlusMinus(); + if (!left) + throw std::runtime_error("Expected left side of 'math pow' expression"); + + while (!consumeToken("**").empty()) { + auto right = parseMathPlusMinus(); + if (!right) + throw 
std::runtime_error( + "Expected right side of 'math pow' expression"); + left = std::make_shared(get_location(), std::move(left), + std::move(right), + BinaryOpExpr::Op::MulMul); + } + return left; + } + + std::shared_ptr parseMathPlusMinus() { + static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))"); + + auto left = parseMathMulDiv(); + if (!left) + throw std::runtime_error( + "Expected left side of 'math plus/minus' expression"); + std::string op_str; + while (!(op_str = consumeToken(plus_minus_tok)).empty()) { + auto right = parseMathMulDiv(); + if (!right) + throw std::runtime_error( + "Expected right side of 'math plus/minus' expression"); + auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub; + left = std::make_shared(get_location(), std::move(left), + std::move(right), op); + } + return left; + } + + std::shared_ptr parseMathMulDiv() { + auto left = parseMathUnaryPlusMinus(); + if (!left) + throw std::runtime_error( + "Expected left side of 'math mul/div' expression"); + + static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))"); + std::string op_str; + while (!(op_str = consumeToken(mul_div_tok)).empty()) { + auto right = parseMathUnaryPlusMinus(); + if (!right) + throw std::runtime_error( + "Expected right side of 'math mul/div' expression"); + auto op = op_str == "*" ? BinaryOpExpr::Op::Mul + : op_str == "**" ? BinaryOpExpr::Op::MulMul + : op_str == "/" ? BinaryOpExpr::Op::Div + : op_str == "//" ? 
BinaryOpExpr::Op::DivDiv + : BinaryOpExpr::Op::Mod; + left = std::make_shared(get_location(), std::move(left), + std::move(right), op); + } + + if (!consumeToken("|").empty()) { + auto expr = parseMathMulDiv(); + if (auto filter = dynamic_cast(expr.get())) { + filter->prepend(std::move(left)); + return expr; + } else { + std::vector> parts; + parts.emplace_back(std::move(left)); + parts.emplace_back(std::move(expr)); + return std::make_shared(get_location(), std::move(parts)); + } + } + return left; + } + + std::shared_ptr call_func(const std::string& name, + ArgumentsExpression&& args) const { + return std::make_shared( + get_location(), std::make_shared(get_location(), name), + std::move(args)); + } + + std::shared_ptr parseMathUnaryPlusMinus() { + static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))"); + auto op_str = consumeToken(unary_plus_minus_tok); + auto expr = parseExpansion(); + if (!expr) + throw std::runtime_error( + "Expected expr of 'unary plus/minus/expansion' expression"); + + if (!op_str.empty()) { + auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus; + return std::make_shared(get_location(), std::move(expr), op); + } + return expr; + } + + std::shared_ptr parseExpansion() { + static std::regex expansion_tok(R"(\*\*?)"); + auto op_str = consumeToken(expansion_tok); + auto expr = parseValueExpression(); + if (op_str.empty()) + return expr; + if (!expr) + throw std::runtime_error("Expected expr of 'expansion' expression"); + return std::make_shared(get_location(), std::move(expr), + op_str == "*" + ? 
UnaryOpExpr::Op::Expansion + : UnaryOpExpr::Op::ExpansionDict); + } + + std::shared_ptr parseValueExpression() { + auto parseValue = [&]() -> std::shared_ptr { + auto location = get_location(); + auto constant = parseConstant(); + if (constant) + return std::make_shared(location, *constant); + + static std::regex null_regex(R"(null\b)"); + if (!consumeToken(null_regex).empty()) + return std::make_shared(location, Value()); + + auto identifier = parseIdentifier(); + if (identifier) + return identifier; + + auto braced = parseBracedExpressionOrArray(); + if (braced) + return braced; + + auto array = parseArray(); + if (array) + return array; + + auto dictionary = parseDictionary(); + if (dictionary) + return dictionary; + + throw std::runtime_error("Expected value expression"); + }; + + auto value = parseValue(); + + while (it != end && consumeSpaces() && peekSymbols({"[", "."})) { + if (!consumeToken("[").empty()) { + std::shared_ptr index; + if (!consumeToken(":").empty()) { + auto slice_end = parseExpression(); + index = std::make_shared(slice_end->location, nullptr, + std::move(slice_end)); + } else { + auto slice_start = parseExpression(); + if (!consumeToken(":").empty()) { + consumeSpaces(); + if (peekSymbols({"]"})) { + index = std::make_shared( + slice_start->location, std::move(slice_start), nullptr); + } else { + auto slice_end = parseExpression(); + index = std::make_shared(slice_start->location, + std::move(slice_start), + std::move(slice_end)); + } + } else { + index = std::move(slice_start); + } + } + if (!index) + throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) + throw std::runtime_error("Expected closing bracket in subscript"); + + value = std::make_shared( + value->location, std::move(value), std::move(index)); + } else if (!consumeToken(".").empty()) { + auto identifier = parseIdentifier(); + if (!identifier) + throw std::runtime_error("Expected identifier in subscript"); + + consumeSpaces(); + if 
(peekSymbols({"("})) { + auto callParams = parseCallArgs(); + value = std::make_shared( + identifier->location, std::move(value), std::move(identifier), + std::move(callParams)); + } else { + auto key = std::make_shared( + identifier->location, Value(identifier->get_name())); + value = std::make_shared( + identifier->location, std::move(value), std::move(key)); + } + } + consumeSpaces(); + } + + if (peekSymbols({"("})) { + auto location = get_location(); + auto callParams = parseCallArgs(); + value = std::make_shared(location, std::move(value), + std::move(callParams)); + } + return value; + } + + std::shared_ptr parseBracedExpressionOrArray() { + if (consumeToken("(").empty()) + return nullptr; + + auto expr = parseExpression(); + if (!expr) + throw std::runtime_error("Expected expression in braced expression"); + + if (!consumeToken(")").empty()) { + return expr; // Drop the parentheses + } + + std::vector> tuple; + tuple.emplace_back(std::move(expr)); + + while (it != end) { + if (consumeToken(",").empty()) + throw std::runtime_error("Expected comma in tuple"); + auto next = parseExpression(); + if (!next) + throw std::runtime_error("Expected expression in tuple"); + tuple.push_back(std::move(next)); + + if (!consumeToken(")").empty()) { + return std::make_shared(get_location(), std::move(tuple)); + } + } + throw std::runtime_error("Expected closing parenthesis"); + } + + std::shared_ptr parseArray() { + if (consumeToken("[").empty()) + return nullptr; + + std::vector> elements; + if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + auto first_expr = parseExpression(); + if (!first_expr) + throw std::runtime_error("Expected first expression in array"); + elements.push_back(std::move(first_expr)); + + while (it != end) { + if (!consumeToken(",").empty()) { + auto expr = parseExpression(); + if (!expr) + throw std::runtime_error("Expected expression in array"); + elements.push_back(std::move(expr)); + } else if 
(!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing bracket in array"); + } + } + throw std::runtime_error("Expected closing bracket"); + } + + std::shared_ptr parseDictionary() { + if (consumeToken("{").empty()) + return nullptr; + + std::vector< + std::pair, std::shared_ptr>> + elements; + if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + + auto parseKeyValuePair = [&]() { + auto key = parseExpression(); + if (!key) + throw std::runtime_error("Expected key in dictionary"); + if (consumeToken(":").empty()) + throw std::runtime_error( + "Expected colon betweek key & value in dictionary"); + auto value = parseExpression(); + if (!value) + throw std::runtime_error("Expected value in dictionary"); + elements.emplace_back(std::pair(std::move(key), std::move(value))); + }; + + parseKeyValuePair(); + + while (it != end) { + if (!consumeToken(",").empty()) { + parseKeyValuePair(); + } else if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error( + "Expected comma or closing brace in dictionary"); + } + } + throw std::runtime_error("Expected closing brace"); + } + + SpaceHandling parsePreSpace(const std::string& s) const { + if (s == "-") + return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + SpaceHandling parsePostSpace(const std::string& s) const { + if (s == "-") + return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + using TemplateTokenVector = std::vector>; + using TemplateTokenIterator = TemplateTokenVector::const_iterator; + + std::vector parseVarNames() { + static std::regex varnames_regex( + R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)"); + + std::vector group; + if ((group = consumeTokenGroups(varnames_regex)).empty()) + throw std::runtime_error("Expected variable names"); + std::vector 
varnames; + std::istringstream iss(group[1]); + std::string varname; + while (std::getline(iss, varname, ',')) { + varnames.push_back(strip(varname)); + } + return varnames; + } + + std::runtime_error unexpected(const TemplateToken& token) const { + return std::runtime_error( + "Unexpected " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + std::runtime_error unterminated(const TemplateToken& token) const { + return std::runtime_error( + "Unterminated " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + + TemplateTokenVector tokenize() { + static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); + static std::regex expr_open_regex(R"(\{\{([-~])?)"); + static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); + static std::regex block_keyword_tok( + R"((if|else|elif|endif|for|endfor|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)"); + static std::regex text_regex(R"([\s\S\n\r]*?($|(?=\{\{|\{%|\{#)))"); + static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); + static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); + + TemplateTokenVector tokens; + std::vector group; + std::string text; + + try { + while (it != end) { + auto location = get_location(); + + if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)) + .empty()) { + auto pre_space = parsePreSpace(group[1]); + auto content = group[2]; + auto post_space = parsePostSpace(group[3]); + tokens.push_back(std::make_unique( + location, pre_space, post_space, content)); + } else if (!(group = consumeTokenGroups(expr_open_regex, + SpaceHandling::Keep)) + .empty()) { + auto pre_space = parsePreSpace(group[1]); + auto expr = parseExpression(); + + if ((group = consumeTokenGroups(expr_close_regex)).empty()) { + throw std::runtime_error("Expected closing expression tag"); + } + + auto post_space = parsePostSpace(group[1]); + 
tokens.push_back(std::make_unique( + location, pre_space, post_space, std::move(expr))); + } else if (!(group = consumeTokenGroups(block_open_regex, + SpaceHandling::Keep)) + .empty()) { + auto pre_space = parsePreSpace(group[1]); + + std::string keyword; + + auto parseBlockClose = [&]() -> SpaceHandling { + if ((group = consumeTokenGroups(block_close_regex)).empty()) + throw std::runtime_error("Expected closing block tag"); + return parsePostSpace(group[1]); + }; + + if ((keyword = consumeToken(block_keyword_tok)).empty()) + throw std::runtime_error("Expected block keyword"); + + if (keyword == "if") { + auto condition = parseExpression(); + if (!condition) + throw std::runtime_error("Expected condition in if block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space, std::move(condition))); + } else if (keyword == "elif") { + auto condition = parseExpression(); + if (!condition) + throw std::runtime_error("Expected condition in elif block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space, std::move(condition))); + } else if (keyword == "else") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space)); + } else if (keyword == "endif") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space)); + } else if (keyword == "for") { + static std::regex recursive_tok(R"(recursive\b)"); + static std::regex if_tok(R"(if\b)"); + + auto varnames = parseVarNames(); + static std::regex in_tok(R"(in\b)"); + if (consumeToken(in_tok).empty()) + throw std::runtime_error("Expected 'in' keyword in for block"); + auto iterable = parseExpression(/* allow_if_expr = */ false); + if (!iterable) + throw std::runtime_error("Expected iterable in for block"); + + std::shared_ptr condition; + if (!consumeToken(if_tok).empty()) { + condition = 
parseExpression(); + } + auto recursive = !consumeToken(recursive_tok).empty(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space, std::move(varnames), + std::move(iterable), std::move(condition), recursive)); + } else if (keyword == "endfor") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space)); + } else if (keyword == "set") { + static std::regex namespaced_var_regex( + R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))"); + + std::string ns; + std::vector var_names; + std::shared_ptr value; + if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { + ns = group[1]; + var_names.push_back(group[2]); + + if (consumeToken("=").empty()) + throw std::runtime_error("Expected equals sign in set block"); + + value = parseExpression(); + if (!value) + throw std::runtime_error("Expected value in set block"); + } else { + var_names = parseVarNames(); + + if (!consumeToken("=").empty()) { + value = parseExpression(); + if (!value) + throw std::runtime_error("Expected value in set block"); + } + } + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space, ns, var_names, + std::move(value))); + } else if (keyword == "endset") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space)); + } else if (keyword == "macro") { + auto macroname = parseIdentifier(); + if (!macroname) + throw std::runtime_error("Expected macro name in macro block"); + auto params = parseParameters(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space, std::move(macroname), + std::move(params))); + } else if (keyword == "endmacro") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space)); + } else if (keyword == "filter") { + auto filter = parseExpression(); + if 
(!filter) + throw std::runtime_error("Expected expression in filter block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space, std::move(filter))); + } else if (keyword == "endfilter") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique( + location, pre_space, post_space)); + } else { + throw std::runtime_error("Unexpected block: " + keyword); + } + } else if (!(text = consumeToken(text_regex, SpaceHandling::Keep)) + .empty()) { + tokens.push_back(std::make_unique( + location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } else { + if (it != end) + throw std::runtime_error("Unexpected character"); + } + } + return tokens; + } catch (const std::exception& e) { + throw std::runtime_error( + e.what() + + error_location_suffix(*template_str, std::distance(start, it))); + } + } + + std::shared_ptr parseTemplate( + const TemplateTokenIterator& begin, TemplateTokenIterator& it, + const TemplateTokenIterator& end, bool fully = false) const { + std::vector> children; + while (it != end) { + const auto start = it; + const auto& token = *(it++); + if (auto if_token = dynamic_cast(token.get())) { + std::vector, + std::shared_ptr>> + cascade; + cascade.emplace_back(std::move(if_token->condition), + parseTemplate(begin, it, end)); + + while (it != end && (*it)->type == TemplateToken::Type::Elif) { + auto elif_token = dynamic_cast((*(it++)).get()); + cascade.emplace_back(std::move(elif_token->condition), + parseTemplate(begin, it, end)); + } + + if (it != end && (*it)->type == TemplateToken::Type::Else) { + cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end)); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) { + throw unterminated(**start); + } + children.emplace_back( + std::make_shared(token->location, std::move(cascade))); + } else if (auto for_token = + dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + auto else_body = 
std::shared_ptr(); + if (it != end && (*it)->type == TemplateToken::Type::Else) { + else_body = parseTemplate(begin, ++it, end); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared( + token->location, std::move(for_token->var_names), + std::move(for_token->iterable), std::move(for_token->condition), + std::move(body), for_token->recursive, std::move(else_body))); + } else if (auto text_token = + dynamic_cast(token.get())) { + SpaceHandling pre_space = + (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; + SpaceHandling post_space = + it != end ? (*it)->pre_space : SpaceHandling::Keep; + + auto text = text_token->text; + if (pre_space == SpaceHandling::Strip) { + static std::regex leading_space_regex(R"(^(\s|\r|\n)+)"); + text = std::regex_replace(text, leading_space_regex, ""); + } else if (options.trim_blocks && (it - 1) != begin && + !dynamic_cast((*(it - 2)).get())) { + static std::regex leading_line(R"(^[ \t]*\r?\n)"); + text = std::regex_replace(text, leading_line, ""); + } + if (post_space == SpaceHandling::Strip) { + static std::regex trailing_space_regex(R"((\s|\r|\n)+$)"); + text = std::regex_replace(text, trailing_space_regex, ""); + } else if (options.lstrip_blocks && it != end) { + static std::regex trailing_last_line_space_regex(R"((\r?\n)[ \t]*$)"); + text = std::regex_replace(text, trailing_last_line_space_regex, "$1"); + } + + if (it == end && !options.keep_trailing_newline) { + static std::regex r(R"(\r?\n$)"); + text = std::regex_replace(text, r, ""); // Strip one trailing newline + } + children.emplace_back( + std::make_shared(token->location, text)); + } else if (auto expr_token = + dynamic_cast(token.get())) { + children.emplace_back(std::make_shared( + token->location, std::move(expr_token->expr))); + } else if (auto set_token = + dynamic_cast(token.get())) { + if (set_token->value) { + children.emplace_back(std::make_shared( + 
token->location, set_token->ns, set_token->var_names, + std::move(set_token->value))); + } else { + auto value_template = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) { + throw unterminated(**start); + } + if (!set_token->ns.empty()) + throw std::runtime_error( + "Namespaced set not supported in set with template value"); + if (set_token->var_names.size() != 1) + throw std::runtime_error( + "Structural assignment not supported in set with template " + "value"); + auto& name = set_token->var_names[0]; + children.emplace_back(std::make_shared( + token->location, name, std::move(value_template))); + } + } else if (auto macro_token = + dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared( + token->location, std::move(macro_token->name), + std::move(macro_token->params), std::move(body))); + } else if (auto filter_token = + dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared( + token->location, std::move(filter_token->filter), std::move(body))); + } else if (dynamic_cast(token.get())) { + // Ignore comments + } else if (dynamic_cast(token.get()) || + dynamic_cast(token.get()) || + dynamic_cast(token.get()) || + dynamic_cast(token.get()) || + dynamic_cast(token.get()) || + dynamic_cast(token.get()) || + dynamic_cast(token.get())) { + it--; // unconsume the token + break; // exit the loop + } else { + throw unexpected(**(it - 1)); + } + } + if (fully && it != end) { + throw unexpected(**it); + } + if (children.empty()) { + return std::make_shared(Location{template_str, 0}, + std::string()); + } else if (children.size() == 1) { + return std::move(children[0]); + } else { + return 
std::make_shared(children[0]->location(), + std::move(children)); + } + } + + public: + static std::shared_ptr parse(const std::string& template_str, + const Options& options) { + Parser parser(std::make_shared(template_str), options); + auto tokens = parser.tokenize(); + TemplateTokenIterator begin = tokens.begin(); + auto it = begin; + TemplateTokenIterator end = tokens.end(); + return parser.parseTemplate(begin, it, end, /* full= */ true); + } +}; + +static Value simple_function( + const std::string& fn_name, const std::vector& params, + const std::function&, Value& args)>& + fn) { + std::map named_positions; + for (size_t i = 0, n = params.size(); i < n; i++) + named_positions[params[i]] = i; + + return Value::callable([=](const std::shared_ptr& context, + ArgumentsValue& args) -> Value { + auto args_obj = Value::object(); + std::vector provided_args(params.size()); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto& arg = args.args[i]; + if (i < params.size()) { + args_obj.set(params[i], arg); + provided_args[i] = true; + } else { + throw std::runtime_error("Too many positional params for " + fn_name); + } + } + for (auto& [name, value] : args.kwargs) { + auto named_pos_it = named_positions.find(name); + if (named_pos_it == named_positions.end()) { + throw std::runtime_error("Unknown argument " + name + " for function " + + fn_name); + } + provided_args[named_pos_it->second] = true; + args_obj.set(name, value); + } + return fn(context, args_obj); + }); +} + +inline std::shared_ptr Context::builtins() { + auto globals = Value::object(); + + globals.set( + "raise_exception", + simple_function( + "raise_exception", {"message"}, + [](const std::shared_ptr&, Value& args) -> Value { + throw std::runtime_error(args.at("message").get()); + })); + globals.set("tojson", + simple_function("tojson", {"value", "indent"}, + [](const std::shared_ptr&, Value& args) { + return Value(args.at("value").dump( + args.get("indent", -1), + /* tojson= */ true)); + })); + 
globals.set("items", + simple_function( + "items", {"object"}, + [](const std::shared_ptr&, Value& args) { + auto items = Value::array(); + if (args.contains("object")) { + auto& obj = args.at("object"); + if (obj.is_string()) { + auto json_obj = json::parse(obj.get()); + for (const auto& kv : json_obj.items()) { + items.push_back(Value::array({kv.key(), kv.value()})); + } + } else if (!obj.is_null()) { + for (auto& key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); + } + } + } + return items; + })); + globals.set("last", simple_function( + "last", {"items"}, + [](const std::shared_ptr&, Value& args) { + auto items = args.at("items"); + if (!items.is_array()) + throw std::runtime_error("object is not a list"); + if (items.size() == 0) + return Value(); + return items.at(items.size() - 1); + })); + globals.set("trim", simple_function( + "trim", {"text"}, + [](const std::shared_ptr&, Value& args) { + auto& text = args.at("text"); + return text.is_null() + ? text + : Value(strip(text.get())); + })); + globals.set("lower", simple_function( + "lower", {"text"}, + [](const std::shared_ptr&, Value& args) { + auto text = args.at("text"); + if (text.is_null()) + return text; + std::string res; + auto str = text.get(); + std::transform(str.begin(), str.end(), + std::back_inserter(res), ::tolower); + return Value(res); + })); + globals.set("default", Value::callable([=](const std::shared_ptr&, + ArgumentsValue& args) { + args.expectArgs("default", {2, 3}, {0, 1}); + auto& value = args.args[0]; + auto& default_value = args.args[1]; + bool boolean = false; + if (args.args.size() == 3) { + boolean = args.args[2].get(); + } else { + Value bv = args.get_named("boolean"); + if (!bv.is_null()) { + boolean = bv.get(); + } + } + return boolean ? (value.to_bool() ? value : default_value) + : value.is_null() ? 
default_value + : value; + })); + auto escape = simple_function( + "escape", {"text"}, [](const std::shared_ptr&, Value& args) { + return Value(html_escape(args.at("text").get())); + }); + globals.set("e", escape); + globals.set("escape", escape); + globals.set( + "joiner", + simple_function( + "joiner", {"sep"}, [](const std::shared_ptr&, Value& args) { + auto sep = args.get("sep", ""); + auto first = std::make_shared(true); + return simple_function("", {}, + [sep, first](const std::shared_ptr&, + const Value&) -> Value { + if (*first) { + *first = false; + return ""; + } + return sep; + }); + return Value(html_escape(args.at("text").get())); + })); + globals.set("count", + simple_function("count", {"items"}, + [](const std::shared_ptr&, Value& args) { + return Value((int64_t)args.at("items").size()); + })); + globals.set( + "dictsort", + simple_function("dictsort", {"value"}, + [](const std::shared_ptr&, Value& args) { + if (args.size() != 1) + throw std::runtime_error( + "dictsort expects exactly 1 argument (TODO: fix " + "implementation)"); + auto& value = args.at("value"); + auto keys = value.keys(); + std::sort(keys.begin(), keys.end()); + auto res = Value::array(); + for (auto& key : keys) { + res.push_back(Value::array({key, value.at(key)})); + } + return res; + })); + globals.set( + "join", + simple_function( + "join", {"items", "d"}, + [](const std::shared_ptr&, Value& args) { + auto do_join = [](Value& items, const std::string& sep) { + std::ostringstream oss; + auto first = true; + for (size_t i = 0, n = items.size(); i < n; ++i) { + if (first) + first = false; + else + oss << sep; + oss << items.at(i).to_str(); + } + return Value(oss.str()); + }; + auto sep = args.get("d", ""); + if (args.contains("items")) { + auto& items = args.at("items"); + return do_join(items, sep); + } else { + return simple_function( + "", {"items"}, + [sep, do_join](const std::shared_ptr&, Value& args) { + auto& items = args.at("items"); + if (!items.to_bool() || 
!items.is_array()) + throw std::runtime_error( + "join expects an array for items, got: " + + items.dump()); + return do_join(items, sep); + }); + } + })); + globals.set("namespace", Value::callable([=](const std::shared_ptr&, + ArgumentsValue& args) { + auto ns = Value::object(); + args.expectArgs("namespace", {0, 0}, + {0, std::numeric_limits::max()}); + for (auto& [name, value] : args.kwargs) { + ns.set(name, value); + } + return ns; + })); + auto equalto = simple_function( + "equalto", {"expected", "actual"}, + [](const std::shared_ptr&, Value& args) -> Value { + return args.at("actual") == args.at("expected"); + }); + globals.set("equalto", equalto); + globals.set("==", equalto); + globals.set("length", simple_function("length", {"items"}, + [](const std::shared_ptr&, + Value& args) -> Value { + auto& items = args.at("items"); + return (int64_t)items.size(); + })); + globals.set("safe", simple_function("safe", {"value"}, + [](const std::shared_ptr&, + Value& args) -> Value { + return args.at("value"); + })); + globals.set("string", simple_function("string", {"value"}, + [](const std::shared_ptr&, + Value& args) -> Value { + return args.at("value").to_str(); + })); + globals.set("int", simple_function("int", {"value"}, + [](const std::shared_ptr&, + Value& args) -> Value { + return args.at("value").to_int(); + })); + globals.set("list", + simple_function( + "list", {"items"}, + [](const std::shared_ptr&, Value& args) -> Value { + auto& items = args.at("items"); + if (!items.is_array()) + throw std::runtime_error("object is not iterable"); + return items; + })); + globals.set("unique", + simple_function( + "unique", {"items"}, + [](const std::shared_ptr&, Value& args) -> Value { + auto& items = args.at("items"); + if (!items.is_array()) + throw std::runtime_error("object is not iterable"); + std::unordered_set seen; + auto result = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto pair = seen.insert(items.at(i)); + if (pair.second) { + 
result.push_back(items.at(i)); + } + } + return result; + })); + auto make_filter = [](const Value& filter, Value& extra_args) -> Value { + return simple_function( + "", {"value"}, + [=](const std::shared_ptr& context, Value& args) { + auto& value = args.at("value"); + ArgumentsValue actual_args; + actual_args.args.emplace_back(value); + for (size_t i = 0, n = extra_args.size(); i < n; i++) { + actual_args.args.emplace_back(extra_args.at(i)); + } + return filter.call(context, actual_args); + }); + }; + // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject + globals.set( + "reject", Value::callable([=](const std::shared_ptr& context, + ArgumentsValue& args) { + args.expectArgs("reject", {2, std::numeric_limits::max()}, + {0, 0}); + auto& items = args.args[0]; + auto filter_fn = context->get(args.args[1]); + if (filter_fn.is_null()) + throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + + auto filter_args = Value::array(); + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.push_back(args.args[i]); + } + auto filter = make_filter(filter_fn, filter_args); + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto& item = items.at(i); + ArgumentsValue filter_args; + filter_args.args.emplace_back(item); + auto pred_res = filter.call(context, filter_args); + if (!pred_res.to_bool()) { + res.push_back(item); + } + } + return res; + })); + globals.set( + "map", Value::callable([=](const std::shared_ptr& context, + ArgumentsValue& args) { + auto res = Value::array(); + if (args.args.size() == 1 && + ((args.has_named("attribute") && args.kwargs.size() == 1) || + (args.has_named("default") && args.kwargs.size() == 2))) { + auto& items = args.args[0]; + auto attr_name = args.get_named("attribute"); + auto default_value = args.get_named("default"); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto& item = items.at(i); + auto attr = item.get(attr_name); + 
res.push_back(attr.is_null() ? default_value : attr); + } + } else if (args.kwargs.empty() && args.args.size() >= 2) { + auto fn = context->get(args.args[1]); + if (fn.is_null()) + throw std::runtime_error("Undefined filter: " + + args.args[1].dump()); + ArgumentsValue filter_args{{Value()}, {}}; + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.args.emplace_back(args.args[i]); + } + for (size_t i = 0, n = args.args[0].size(); i < n; i++) { + auto& item = args.args[0].at(i); + filter_args.args[0] = item; + res.push_back(fn.call(context, filter_args)); + } + } else { + throw std::runtime_error("Invalid or unsupported arguments for map"); + } + return res; + })); + globals.set("indent", + simple_function("indent", {"text", "indent", "first"}, + [](const std::shared_ptr&, Value& args) { + auto text = args.at("text").get(); + auto first = args.get("first", false); + std::string out; + std::string indent( + args.get("indent", 0), ' '); + std::istringstream iss(text); + std::string line; + auto is_first = true; + while (std::getline(iss, line, '\n')) { + auto needs_indent = !is_first || first; + if (is_first) + is_first = false; + else + out += "\n"; + if (needs_indent) + out += indent; + out += line; + } + if (!text.empty() && text.back() == '\n') + out += "\n"; + return out; + })); + globals.set( + "selectattr", Value::callable([=](const std::shared_ptr& context, + ArgumentsValue& args) { + args.expectArgs("selectattr", {2, std::numeric_limits::max()}, + {0, 0}); + auto& items = args.args[0]; + if (items.is_null()) + return Value::array(); + auto attr_name = args.args[1].get(); + + bool has_test = false; + Value test_fn; + ArgumentsValue test_args{{Value()}, {}}; + if (args.args.size() >= 3) { + has_test = true; + test_fn = context->get(args.args[2]); + if (test_fn.is_null()) + throw std::runtime_error("Undefined test: " + args.args[2].dump()); + for (size_t i = 3, n = args.args.size(); i < n; i++) { + test_args.args.emplace_back(args.args[i]); + } 
+ test_args.kwargs = args.kwargs; + } + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto& item = items.at(i); + auto attr = item.get(attr_name); + if (has_test) { + test_args.args[0] = attr; + if (test_fn.call(context, test_args).to_bool()) { + res.push_back(item); + } + } else { + res.push_back(attr); + } + } + return res; + })); + globals.set("range", Value::callable([=](const std::shared_ptr&, + ArgumentsValue& args) { + std::vector startEndStep(3); + std::vector param_set(3); + if (args.args.size() == 1) { + startEndStep[1] = args.args[0].get(); + param_set[1] = true; + } else { + for (size_t i = 0; i < args.args.size(); i++) { + auto& arg = args.args[i]; + auto v = arg.get(); + startEndStep[i] = v; + param_set[i] = true; + } + } + for (auto& [name, value] : args.kwargs) { + size_t i; + if (name == "start") + i = 0; + else if (name == "end") + i = 1; + else if (name == "step") + i = 2; + else + throw std::runtime_error("Unknown argument " + name + + " for function range"); + + if (param_set[i]) { + throw std::runtime_error("Duplicate argument " + name + + " for function range"); + } + startEndStep[i] = value.get(); + param_set[i] = true; + } + if (!param_set[1]) { + throw std::runtime_error( + "Missing required argument 'end' for function range"); + } + int64_t start = param_set[0] ? startEndStep[0] : 0; + int64_t end = startEndStep[1]; + int64_t step = param_set[2] ? startEndStep[2] : 1; + + auto res = Value::array(); + if (step > 0) { + for (int64_t i = start; i < end; i += step) { + res.push_back(Value(i)); + } + } else { + for (int64_t i = start; i > end; i += step) { + res.push_back(Value(i)); + } + } + return res; + })); + + return std::make_shared(std::move(globals)); +} + +inline std::shared_ptr Context::make( + Value&& values, const std::shared_ptr& parent) { + return std::make_shared( + values.is_null() ? 
Value::object() : std::move(values), parent); +} + +} // namespace minja From 1a73c0cffc0eecd3bf0b51f2b009575e442eadbe Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 26 Dec 2024 09:25:29 +0700 Subject: [PATCH 03/16] fix: forward start model parameters (#1825) Co-authored-by: vansangpfiev --- engine/controllers/models.cc | 53 ++++++-------------- engine/services/model_service.cc | 35 +++++++------ engine/services/model_service.h | 19 +------ engine/test/components/test_json_helper.cc | 58 ++++++++++++++++++++++ engine/utils/json_helper.h | 24 +++++++++ 5 files changed, 116 insertions(+), 73 deletions(-) diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 59793b2a6..1c33ab1dc 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -488,55 +488,31 @@ void Models::StartModel( if (!http_util::HasFieldInReq(req, callback, "model")) return; auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); - StartParameterOverride params_override; - if (auto& o = (*(req->getJsonObject()))["prompt_template"]; !o.isNull()) { - params_override.custom_prompt_template = o.asString(); - } - - if (auto& o = (*(req->getJsonObject()))["cache_enabled"]; !o.isNull()) { - params_override.cache_enabled = o.asBool(); - } - - if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) { - params_override.ngl = o.asInt(); - } - - if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) { - params_override.n_parallel = o.asInt(); - } - - if (auto& o = (*(req->getJsonObject()))["ctx_len"]; !o.isNull()) { - params_override.ctx_len = o.asInt(); - } - - if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) { - params_override.cache_type = o.asString(); - } + std::optional mmproj; if (auto& o = (*(req->getJsonObject()))["mmproj"]; !o.isNull()) { - params_override.mmproj = o.asString(); + mmproj = o.asString(); } + auto bypass_llama_model_path = false; // Support both llama_model_path and model_path for 
backward compatible // model_path has higher priority if (auto& o = (*(req->getJsonObject()))["llama_model_path"]; !o.isNull()) { - params_override.model_path = o.asString(); + auto model_path = o.asString(); if (auto& mp = (*(req->getJsonObject()))["model_path"]; mp.isNull()) { // Bypass if model does not exist in DB and llama_model_path exists - if (std::filesystem::exists(params_override.model_path.value()) && + if (std::filesystem::exists(model_path) && !model_service_->HasModel(model_handle)) { CTL_INF("llama_model_path exists, bypass check model id"); - params_override.bypass_llama_model_path = true; + bypass_llama_model_path = true; } } } - if (auto& o = (*(req->getJsonObject()))["model_path"]; !o.isNull()) { - params_override.model_path = o.asString(); - } + auto bypass_model_check = (mmproj.has_value() || bypass_llama_model_path); auto model_entry = model_service_->GetDownloadedModel(model_handle); - if (!model_entry.has_value() && !params_override.bypass_model_check()) { + if (!model_entry.has_value() && !bypass_model_check) { Json::Value ret; ret["message"] = "Cannot find model: " + model_handle; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); @@ -544,9 +520,8 @@ void Models::StartModel( callback(resp); return; } - std::string engine_name = params_override.bypass_model_check() - ? kLlamaEngine - : model_entry.value().engine; + std::string engine_name = + bypass_model_check ? 
kLlamaEngine : model_entry.value().engine; auto engine_validate = engine_service_->IsEngineReady(engine_name); if (engine_validate.has_error()) { Json::Value ret; @@ -565,7 +540,9 @@ void Models::StartModel( return; } - auto result = model_service_->StartModel(model_handle, params_override); + auto result = model_service_->StartModel( + model_handle, *(req->getJsonObject()) /*params_override*/, + bypass_model_check); if (result.has_error()) { Json::Value ret; ret["message"] = result.error(); @@ -668,7 +645,7 @@ void Models::AddRemoteModel( auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); auto engine_name = (*(req->getJsonObject())).get("engine", "").asString(); - + auto engine_validate = engine_service_->IsEngineReady(engine_name); if (engine_validate.has_error()) { Json::Value ret; @@ -687,7 +664,7 @@ void Models::AddRemoteModel( callback(resp); return; } - + config::RemoteModelConfig model_config; model_config.LoadFromJson(*(req->getJsonObject())); cortex::db::Models modellist_utils_obj; diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 0d909b61f..be0eb12a7 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -749,19 +749,28 @@ cpp::result ModelService::DeleteModel( } cpp::result ModelService::StartModel( - const std::string& model_handle, - const StartParameterOverride& params_override) { + const std::string& model_handle, const Json::Value& params_override, + bool bypass_model_check) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; + std::optional custom_prompt_template; + std::optional ctx_len; + if (auto& o = params_override["prompt_template"]; !o.isNull()) { + custom_prompt_template = o.asString(); + } + + if (auto& o = params_override["ctx_len"]; !o.isNull()) { + ctx_len = o.asInt(); + } try { constexpr const int kDefautlContextLength = 8192; int 
max_model_context_length = kDefautlContextLength; Json::Value json_data; // Currently we don't support download vision models, so we need to bypass check - if (!params_override.bypass_model_check()) { + if (!bypass_model_check) { auto model_entry = modellist_handler.GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); @@ -839,29 +848,19 @@ cpp::result ModelService::StartModel( } json_data["model"] = model_handle; - if (auto& cpt = params_override.custom_prompt_template; - !cpt.value_or("").empty()) { + if (auto& cpt = custom_prompt_template; !cpt.value_or("").empty()) { auto parse_prompt_result = string_utils::ParsePrompt(cpt.value()); json_data["system_prompt"] = parse_prompt_result.system_prompt; json_data["user_prompt"] = parse_prompt_result.user_prompt; json_data["ai_prompt"] = parse_prompt_result.ai_prompt; } -#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \ - if (param_override.param_name) { \ - json_obj[#param_name] = param_override.param_name.value(); \ - } + json_helper::MergeJson(json_data, params_override); - ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled); - ASSIGN_IF_PRESENT(json_data, params_override, ngl); - ASSIGN_IF_PRESENT(json_data, params_override, n_parallel); - ASSIGN_IF_PRESENT(json_data, params_override, cache_type); - ASSIGN_IF_PRESENT(json_data, params_override, mmproj); - ASSIGN_IF_PRESENT(json_data, params_override, model_path); -#undef ASSIGN_IF_PRESENT - if (params_override.ctx_len) { + // Set the latest ctx_len + if (ctx_len) { json_data["ctx_len"] = - std::min(params_override.ctx_len.value(), max_model_context_length); + std::min(ctx_len.value(), max_model_context_length); } CTL_INF(json_data.toStyledString()); auto may_fallback_res = MayFallbackToCpu(json_data["model_path"].asString(), diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 8b24b3421..ab3596812 100644 --- a/engine/services/model_service.h +++ 
b/engine/services/model_service.h @@ -22,21 +22,6 @@ struct ModelPullInfo { std::string download_url; }; -struct StartParameterOverride { - std::optional cache_enabled; - std::optional ngl; - std::optional n_parallel; - std::optional ctx_len; - std::optional custom_prompt_template; - std::optional cache_type; - std::optional mmproj; - std::optional model_path; - bool bypass_llama_model_path = false; - bool bypass_model_check() const { - return mmproj.has_value() || bypass_llama_model_path; - } -}; - struct StartModelResult { bool success; std::optional warning; @@ -82,8 +67,8 @@ class ModelService { cpp::result DeleteModel(const std::string& model_handle); cpp::result StartModel( - const std::string& model_handle, - const StartParameterOverride& params_override); + const std::string& model_handle, const Json::Value& params_override, + bool bypass_model_check); cpp::result StopModel(const std::string& model_handle); diff --git a/engine/test/components/test_json_helper.cc b/engine/test/components/test_json_helper.cc index cb3f4683a..ba5e27165 100644 --- a/engine/test/components/test_json_helper.cc +++ b/engine/test/components/test_json_helper.cc @@ -33,3 +33,61 @@ TEST(ParseJsonStringTest, EmptyString) { EXPECT_TRUE(result.isNull()); } + +TEST(MergeJsonTest, MergeSimpleObjects) { + Json::Value json1, json2; + json1["name"] = "John"; + json1["age"] = 30; + + json2["age"] = 31; + json2["email"] = "john@example.com"; + + json_helper::MergeJson(json1, json2); + + Json::Value expected; + expected["name"] = "John"; + expected["age"] = 31; + expected["email"] = "john@example.com"; + + EXPECT_EQ(json1, expected); +} + +TEST(MergeJsonTest, MergeNestedObjects) { + Json::Value json1, json2; + json1["person"]["name"] = "John"; + json1["person"]["age"] = 30; + + json2["person"]["age"] = 31; + json2["person"]["email"] = "john@example.com"; + + json_helper::MergeJson(json1, json2); + + Json::Value expected; + expected["person"]["name"] = "John"; + expected["person"]["age"] = 31; + 
expected["person"]["email"] = "john@example.com"; + + EXPECT_EQ(json1, expected); +} + +TEST(MergeJsonTest, MergeArrays) { + Json::Value json1, json2; + json1["hobbies"] = Json::Value(Json::arrayValue); + json1["hobbies"].append("reading"); + json1["hobbies"].append("painting"); + + json2["hobbies"] = Json::Value(Json::arrayValue); + json2["hobbies"].append("hiking"); + json2["hobbies"].append("painting"); + + json_helper::MergeJson(json1, json2); + + Json::Value expected; + expected["hobbies"] = Json::Value(Json::arrayValue); + expected["hobbies"].append("reading"); + expected["hobbies"].append("painting"); + expected["hobbies"].append("hiking"); + expected["hobbies"].append("painting"); + + EXPECT_EQ(json1, expected); +} diff --git a/engine/utils/json_helper.h b/engine/utils/json_helper.h index 82f994751..3b08651c4 100644 --- a/engine/utils/json_helper.h +++ b/engine/utils/json_helper.h @@ -16,4 +16,28 @@ inline std::string DumpJsonString(const Json::Value& json) { builder["indentation"] = ""; return Json::writeString(builder, json); } + +inline void MergeJson(Json::Value& target, const Json::Value& source) { + for (const auto& member : source.getMemberNames()) { + if (target.isMember(member)) { + // If the member exists in both objects, recursively merge the values + if (target[member].type() == Json::objectValue && + source[member].type() == Json::objectValue) { + MergeJson(target[member], source[member]); + } else if (target[member].type() == Json::arrayValue && + source[member].type() == Json::arrayValue) { + // If the member is an array in both objects, merge the arrays + for (const auto& value : source[member]) { + target[member].append(value); + } + } else { + // Otherwise, overwrite the value in the target with the value from the source + target[member] = source[member]; + } + } else { + // If the member doesn't exist in the target, add it + target[member] = source[member]; + } + } +} } // namespace json_helper From 
3456c7b83974285feb71a9848bed3ab59aa36ef1 Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 27 Dec 2024 08:27:44 +0700 Subject: [PATCH 04/16] fix: not create new folder if is registering paths (#1828) --- engine/services/engine_service.cc | 7 ++++--- engine/utils/file_manager_utils.cc | 5 +++-- engine/utils/file_manager_utils.h | 3 ++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 2ca06cb33..93311f98b 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -132,8 +132,8 @@ cpp::result EngineService::UnzipEngine( CTL_INF("Found cuda variant, extract it"); found_cuda = true; // extract binary - auto cuda_path = - file_manager_utils::GetCudaToolkitPath(NormalizeEngine(engine)); + auto cuda_path = file_manager_utils::GetCudaToolkitPath( + NormalizeEngine(engine), true); archive_utils::ExtractArchive(path + "/" + cf, cuda_path.string(), true); } @@ -434,7 +434,8 @@ cpp::result EngineService::DownloadCuda( }}; auto on_finished = [engine](const DownloadTask& finishedTask) { - auto engine_path = file_manager_utils::GetCudaToolkitPath(engine); + auto engine_path = file_manager_utils::GetCudaToolkitPath(engine, true); + archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(), engine_path.string()); try { diff --git a/engine/utils/file_manager_utils.cc b/engine/utils/file_manager_utils.cc index 338abadac..aee65020c 100644 --- a/engine/utils/file_manager_utils.cc +++ b/engine/utils/file_manager_utils.cc @@ -289,13 +289,14 @@ std::filesystem::path GetModelsContainerPath() { return models_container_path; } -std::filesystem::path GetCudaToolkitPath(const std::string& engine) { +std::filesystem::path GetCudaToolkitPath(const std::string& engine, + bool create_if_not_exist) { auto engine_path = getenv("ENGINE_PATH") ? 
std::filesystem::path(getenv("ENGINE_PATH")) : GetCortexDataPath(); auto cuda_path = engine_path / "engines" / engine / "deps"; - if (!std::filesystem::exists(cuda_path)) { + if (create_if_not_exist && !std::filesystem::exists(cuda_path)) { std::filesystem::create_directories(cuda_path); } diff --git a/engine/utils/file_manager_utils.h b/engine/utils/file_manager_utils.h index 91102d002..059fe6ae3 100644 --- a/engine/utils/file_manager_utils.h +++ b/engine/utils/file_manager_utils.h @@ -45,7 +45,8 @@ void CreateDirectoryRecursively(const std::string& path); std::filesystem::path GetModelsContainerPath(); -std::filesystem::path GetCudaToolkitPath(const std::string& engine); +std::filesystem::path GetCudaToolkitPath(const std::string& engine, + bool create_if_not_exist = false); std::filesystem::path GetEnginesContainerPath(); From f94527fd43e576ca096596f94ab2e7005ad82267 Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 27 Dec 2024 21:49:19 +0700 Subject: [PATCH 05/16] feat: add openai assistant (#1826) --- docs/static/openapi/cortex.json | 542 ++++++++++++++++-- engine/common/assistant.h | 271 ++++++++- .../common/assistant_code_interpreter_tool.h | 32 ++ engine/common/assistant_file_search_tool.h | 151 +++++ engine/common/assistant_function_tool.h | 130 +++++ engine/common/assistant_tool.h | 88 +-- engine/common/dto/assistant_create_dto.h | 211 +++++++ engine/common/dto/assistant_update_dto.h | 201 +++++++ engine/common/dto/base_dto.h | 16 + engine/common/message_attachment.h | 15 +- .../common/repository/assistant_repository.h | 25 + engine/common/thread.h | 12 +- engine/common/thread_tool_resources.h | 50 -- engine/common/tool_resources.h | 114 ++++ engine/controllers/assistants.cc | 185 +++++- engine/controllers/assistants.h | 39 ++ engine/main.cc | 6 +- .../repositories/assistant_fs_repository.cc | 214 +++++++ engine/repositories/assistant_fs_repository.h | 59 ++ engine/repositories/file_fs_repository.h | 2 +- engine/repositories/message_fs_repository.h | 2 +- 
engine/services/assistant_service.cc | 180 ++++++ engine/services/assistant_service.h | 32 +- engine/services/thread_service.cc | 4 +- engine/services/thread_service.h | 5 +- engine/test/components/test_assistant.cc | 194 +++++++ .../test_assistant_tool_code_interpreter.cc | 49 ++ .../test_assistant_tool_file_search.cc | 207 +++++++ .../test_assistant_tool_function.cc | 240 ++++++++ engine/test/components/test_tool_resources.cc | 212 +++++++ 30 files changed, 3289 insertions(+), 199 deletions(-) create mode 100644 engine/common/assistant_code_interpreter_tool.h create mode 100644 engine/common/assistant_file_search_tool.h create mode 100644 engine/common/assistant_function_tool.h create mode 100644 engine/common/dto/assistant_create_dto.h create mode 100644 engine/common/dto/assistant_update_dto.h create mode 100644 engine/common/dto/base_dto.h create mode 100644 engine/common/repository/assistant_repository.h delete mode 100644 engine/common/thread_tool_resources.h create mode 100644 engine/common/tool_resources.h create mode 100644 engine/repositories/assistant_fs_repository.cc create mode 100644 engine/repositories/assistant_fs_repository.h create mode 100644 engine/test/components/test_assistant.cc create mode 100644 engine/test/components/test_assistant_tool_code_interpreter.cc create mode 100644 engine/test/components/test_assistant_tool_file_search.cc create mode 100644 engine/test/components/test_assistant_tool_function.cc create mode 100644 engine/test/components/test_tool_resources.cc diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 479e300ce..d006f0f2d 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -5,77 +5,470 @@ "post": { "operationId": "AssistantsController_create", "summary": "Create assistant", - "description": "Creates a new assistant.", - "parameters": [], + "description": "Creates a new assistant with the specified configuration.", "requestBody": { "required": true, 
"content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateAssistantDto" + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The model identifier to use for the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant. Maximum of 128 tools.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs for the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." + }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "required": ["model"] } } } }, "responses": { - "201": { - "description": "The assistant has been successfully created." + "200": { + "description": "Ok", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." 
+ }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs that can be attached to the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." 
+ }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] + } + } + } } }, "tags": ["Assistants"] }, - "get": { - "operationId": "AssistantsController_findAll", - "summary": "List assistants", - "description": "Returns a list of assistants.", + "patch": { + "operationId": "AssistantsController_update", + "summary": "Update assistant", + "description": "Updates an assistant. Requires at least one modifiable field.", "parameters": [ { - "name": "limit", - "required": false, - "in": "query", - "description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.", - "schema": { - "type": "number" - } - }, - { - "name": "order", - "required": false, - "in": "query", - "description": "Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.", - "schema": { - "type": "string" - } - }, - { - "name": "after", - "required": false, - "in": "query", - "description": "A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.", + "name": "id", + "required": true, + "in": "path", + "description": "The unique identifier of the assistant.", "schema": { "type": "string" } }, { - "name": "before", - "required": false, - "in": "query", - "description": "A cursor for use in pagination. before is an object ID that defines your place in the list. 
For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.", + "name": "OpenAI-Beta", + "required": true, + "in": "header", + "description": "Beta feature header.", "schema": { - "type": "string" + "type": "string", + "enum": ["assistants=v2"] } } ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The model identifier to use for the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant. Maximum of 128 tools.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs for the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." 
+ }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "minProperties": 1 + } + } + } + }, "responses": { "200": { "description": "Ok", "content": { "application/json": { "schema": { - "type": "array", - "items": { - "$ref": "#/components/schemas/AssistantEntity" - } + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs that can be attached to the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." 
+ }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] + } + } + } + } + }, + "tags": ["Assistants"] + }, + "get": { + "operationId": "AssistantsController_list", + "summary": "List assistants", + "description": "Returns a list of assistants.", + "responses": { + "200": { + "description": "Ok", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "object": { + "type": "string", + "enum": ["list"], + "description": "The object type, which is always 'list' for a list response." + }, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." 
+ }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs that can be attached to the assistant.", + "additionalProperties": true + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] + } + } + }, + "required": ["object", "data"] } } } @@ -88,7 +481,7 @@ "get": { "operationId": "AssistantsController_findOne", "summary": "Get assistant", - "description": "Retrieves a specific assistant defined by an assistant's `id`.", + "description": "Retrieves a specific assistant by ID.", "parameters": [ { "name": "id", @@ -98,6 +491,16 @@ "schema": { "type": "string" } + }, + { + "name": "OpenAI-Beta", + "required": true, + "in": "header", + "description": "Beta feature header.", + "schema": { + "type": "string", + "enum": ["assistants=v2"] + } } ], "responses": { @@ -106,7 +509,38 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/AssistantEntity" + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." 
+ }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs attached to the assistant.", + "additionalProperties": true + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] } } } @@ -117,7 +551,7 @@ "delete": { "operationId": "AssistantsController_remove", "summary": "Delete assistant", - "description": "Deletes a specific assistant defined by an assistant's `id`.", + "description": "Deletes a specific assistant by ID.", "parameters": [ { "name": "id", @@ -131,11 +565,28 @@ ], "responses": { "200": { - "description": "The assistant has been successfully deleted.", + "description": "Ok", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/DeleteAssistantResponseDto" + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the deleted assistant." + }, + "object": { + "type": "string", + "enum": ["assistant.deleted"], + "description": "The object type for a deleted assistant." + }, + "deleted": { + "type": "boolean", + "enum": [true], + "description": "Indicates the assistant was successfully deleted." 
+ } + }, + "required": ["id", "object", "deleted"] } } } @@ -3456,6 +3907,7 @@ "Files", "Hardware", "Events", + "Assistants", "Threads", "Messages", "Pulling Models", diff --git a/engine/common/assistant.h b/engine/common/assistant.h index e49147e9e..6210a0c2c 100644 --- a/engine/common/assistant.h +++ b/engine/common/assistant.h @@ -1,9 +1,13 @@ #pragma once #include +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" #include "common/assistant_tool.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" +#include "utils/logging_utils.h" #include "utils/result.hpp" namespace OpenAi { @@ -75,7 +79,49 @@ struct JanAssistant : JsonSerializable { } }; -struct Assistant { +struct Assistant : JsonSerializable { + Assistant() = default; + + ~Assistant() = default; + + Assistant(const Assistant&) = delete; + + Assistant& operator=(const Assistant&) = delete; + + Assistant(Assistant&& other) noexcept + : id{std::move(other.id)}, + object{std::move(other.object)}, + created_at{other.created_at}, + name{std::move(other.name)}, + description{std::move(other.description)}, + model(std::move(other.model)), + instructions(std::move(other.instructions)), + tools(std::move(other.tools)), + tool_resources(std::move(other.tool_resources)), + metadata(std::move(other.metadata)), + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + Assistant& operator=(Assistant&& other) noexcept { + if (this != &other) { + id = std::move(other.id); + object = std::move(other.object); + created_at = other.created_at; + name = std::move(other.name); + description = std::move(other.description); + model = std::move(other.model); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources); + metadata 
= std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + /** * The identifier, which can be referenced in API endpoints. */ @@ -126,8 +172,7 @@ struct Assistant { * requires a list of file IDs, while the file_search tool requires a list * of vector store IDs. */ - std::optional> - tool_resources; + std::unique_ptr tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. This can be @@ -153,5 +198,223 @@ struct Assistant { * We generally recommend altering this or temperature but not both. */ std::optional top_p; + + std::variant response_format; + + cpp::result ToJson() override { + try { + Json::Value root; + + root["id"] = std::move(id); + root["object"] = "assistant"; + root["created_at"] = created_at; + if (name.has_value()) { + root["name"] = name.value(); + } + if (description.has_value()) { + root["description"] = description.value(); + } + root["model"] = model; + if (instructions.has_value()) { + root["instructions"] = instructions.value(); + } + + Json::Value tools_jarr{Json::arrayValue}; + for (auto& tool_ptr : tools) { + if (auto it = tool_ptr->ToJson(); it.has_value()) { + tools_jarr.append(it.value()); + } else { + CTL_WRN("Failed to convert content to json: " + it.error()); + } + } + root["tools"] = tools_jarr; + if (tool_resources) { + Json::Value tool_resources_json{Json::objectValue}; + + if (auto* code_interpreter = + dynamic_cast(tool_resources.get())) { + auto result = code_interpreter->ToJson(); + if (result.has_value()) { + tool_resources_json["code_interpreter"] = result.value(); + } else { + CTL_WRN("Failed to convert code_interpreter to json: " + + result.error()); + } + } else if (auto* file_search = dynamic_cast( + tool_resources.get())) { + auto result = file_search->ToJson(); + if (result.has_value()) { + tool_resources_json["file_search"] = result.value(); + } else { + 
CTL_WRN("Failed to convert file_search to json: " + result.error()); + } + } + + // Only add tool_resources to root if we successfully serialized some resources + if (!tool_resources_json.empty()) { + root["tool_resources"] = tool_resources_json; + } + } + Json::Value metadata_json{Json::objectValue}; + for (const auto& [key, value] : metadata) { + if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else { + metadata_json[key] = std::get(value); + } + } + root["metadata"] = metadata_json; + + if (temperature.has_value()) { + root["temperature"] = temperature.value(); + } + if (top_p.has_value()) { + root["top_p"] = top_p.value(); + } + return root; + } catch (const std::exception& e) { + return cpp::fail("ToJson failed: " + std::string(e.what())); + } + } + + static cpp::result FromJson(Json::Value&& json) { + try { + Assistant assistant; + + // Parse required fields + if (!json.isMember("id") || !json["id"].isString()) { + return cpp::fail("Missing or invalid 'id' field"); + } + assistant.id = json["id"].asString(); + + if (!json.isMember("object") || !json["object"].isString() || + json["object"].asString() != "assistant") { + return cpp::fail("Missing or invalid 'object' field"); + } + + if (!json.isMember("created_at") || !json["created_at"].isUInt64()) { + return cpp::fail("Missing or invalid 'created_at' field"); + } + assistant.created_at = json["created_at"].asUInt64(); + + if (!json.isMember("model") || !json["model"].isString()) { + return cpp::fail("Missing or invalid 'model' field"); + } + assistant.model = json["model"].asString(); + + // Parse optional fields + if (json.isMember("name") && json["name"].isString()) { + assistant.name = json["name"].asString(); + } + + if (json.isMember("description") && json["description"].isString()) { + 
assistant.description = json["description"].asString(); + } + + if (json.isMember("instructions") && json["instructions"].isString()) { + assistant.instructions = json["instructions"].asString(); + } + + // Parse tools array + if (json.isMember("tools") && json["tools"].isArray()) { + auto tools_array = json["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + + result.error()); + } + } else if (tool_type == "function") { + auto result = AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + + if (json.isMember("tool_resources") && + json["tool_resources"].isObject()) { + const auto& tool_resources_json = json["tool_resources"]; + + // Parse code interpreter resources + if (tool_resources_json.isMember("code_interpreter")) { + auto result = OpenAi::CodeInterpreter::FromJson( + tool_resources_json["code_interpreter"]); + if (result.has_value()) { + assistant.tool_resources = + std::make_unique( + std::move(result.value())); + } else { + CTL_WRN("Failed to parse code_interpreter resources: " + + 
result.error()); + } + } + + // Parse file search resources + if (tool_resources_json.isMember("file_search")) { + auto result = + OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]); + if (result.has_value()) { + assistant.tool_resources = + std::make_unique(std::move(result.value())); + } else { + CTL_WRN("Failed to parse file_search resources: " + result.error()); + } + } + } + + // Parse metadata + if (json.isMember("metadata") && json["metadata"].isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_value()) { + assistant.metadata = res.value(); + } else { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } + } + + if (json.isMember("temperature") && json["temperature"].isDouble()) { + assistant.temperature = json["temperature"].asFloat(); + } + + if (json.isMember("top_p") && json["top_p"].isDouble()) { + assistant.top_p = json["top_p"].asFloat(); + } + + return assistant; + } catch (const std::exception& e) { + return cpp::fail("FromJson failed: " + std::string(e.what())); + } + } }; } // namespace OpenAi diff --git a/engine/common/assistant_code_interpreter_tool.h b/engine/common/assistant_code_interpreter_tool.h new file mode 100644 index 000000000..43bfac47c --- /dev/null +++ b/engine/common/assistant_code_interpreter_tool.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common/assistant_tool.h" + +namespace OpenAi { +struct AssistantCodeInterpreterTool : public AssistantTool { + AssistantCodeInterpreterTool() : AssistantTool("code_interpreter") {} + + AssistantCodeInterpreterTool(const AssistantCodeInterpreterTool&) = delete; + + AssistantCodeInterpreterTool& operator=(const AssistantCodeInterpreterTool&) = + delete; + + AssistantCodeInterpreterTool(AssistantCodeInterpreterTool&&) = default; + + AssistantCodeInterpreterTool& operator=(AssistantCodeInterpreterTool&&) = + default; + + ~AssistantCodeInterpreterTool() = default; + + static cpp::result FromJson() { + AssistantCodeInterpreterTool 
tool; + return std::move(tool); + } + + cpp::result ToJson() override { + Json::Value json; + json["type"] = type; + return json; + } +}; +} // namespace OpenAi diff --git a/engine/common/assistant_file_search_tool.h b/engine/common/assistant_file_search_tool.h new file mode 100644 index 000000000..2abaa7f6e --- /dev/null +++ b/engine/common/assistant_file_search_tool.h @@ -0,0 +1,151 @@ +#pragma once + +#include "common/assistant_tool.h" +#include "common/json_serializable.h" + +namespace OpenAi { +struct FileSearchRankingOption : public JsonSerializable { + /** + * The ranker to use for the file search. If not specified will use the auto ranker. + */ + std::string ranker; + + /** + * The score threshold for the file search. All values must be a + * floating point number between 0 and 1. + */ + float score_threshold; + + FileSearchRankingOption(float score_threshold, + const std::string& ranker = "auto") + : ranker{ranker}, score_threshold{score_threshold} {} + + FileSearchRankingOption(const FileSearchRankingOption&) = delete; + + FileSearchRankingOption& operator=(const FileSearchRankingOption&) = delete; + + FileSearchRankingOption(FileSearchRankingOption&&) = default; + + FileSearchRankingOption& operator=(FileSearchRankingOption&&) = default; + + ~FileSearchRankingOption() = default; + + static cpp::result FromJson( + const Json::Value& json) { + if (!json.isMember("score_threshold")) { + return cpp::fail("score_threshold must be provided"); + } + + FileSearchRankingOption option{ + json["score_threshold"].asFloat(), + std::move(json.get("ranker", "auto").asString())}; + return option; + } + + cpp::result ToJson() override { + Json::Value json; + json["ranker"] = ranker; + json["score_threshold"] = score_threshold; + return json; + } +}; + +/** + * Overrides for the file search tool. + */ +struct AssistantFileSearch : public JsonSerializable { + /** + * The maximum number of results the file search tool should output. 
+ * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. + * This number should be between 1 and 50 inclusive. + * + * Note that the file search tool may output fewer than max_num_results results. + * See the file search tool documentation for more information. + */ + int max_num_results; + + /** + * The ranking options for the file search. If not specified, + * the file search tool will use the auto ranker and a score_threshold of 0. + * + * See the file search tool documentation for more information. + */ + FileSearchRankingOption ranking_options; + + AssistantFileSearch(int max_num_results, + FileSearchRankingOption&& ranking_options) + : max_num_results{max_num_results}, + ranking_options{std::move(ranking_options)} {} + + AssistantFileSearch(const AssistantFileSearch&) = delete; + + AssistantFileSearch& operator=(const AssistantFileSearch&) = delete; + + AssistantFileSearch(AssistantFileSearch&&) = default; + + AssistantFileSearch& operator=(AssistantFileSearch&&) = default; + + ~AssistantFileSearch() = default; + + static cpp::result FromJson( + const Json::Value& json) { + try { + AssistantFileSearch search{ + json["max_num_results"].asInt(), + FileSearchRankingOption::FromJson(json["ranking_options"]).value()}; + return search; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + Json::Value root; + root["max_num_results"] = max_num_results; + root["ranking_options"] = ranking_options.ToJson().value(); + return root; + } +}; + +struct AssistantFileSearchTool : public AssistantTool { + AssistantFileSearch file_search; + + AssistantFileSearchTool(AssistantFileSearch& file_search) + : AssistantTool("file_search"), file_search{std::move(file_search)} {} + + AssistantFileSearchTool(const AssistantFileSearchTool&) = delete; + + AssistantFileSearchTool& operator=(const AssistantFileSearchTool&) = delete; + + AssistantFileSearchTool(AssistantFileSearchTool&&) = 
default; + + AssistantFileSearchTool& operator=(AssistantFileSearchTool&&) = default; + + ~AssistantFileSearchTool() = default; + + static cpp::result FromJson( + const Json::Value& json) { + try { + AssistantFileSearch search{json["file_search"]["max_num_results"].asInt(), + FileSearchRankingOption::FromJson( + json["file_search"]["ranking_options"]) + .value()}; + AssistantFileSearchTool tool{search}; + return tool; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value root; + root["type"] = type; + root["file_search"] = file_search.ToJson().value(); + return root; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +}; // namespace OpenAi diff --git a/engine/common/assistant_function_tool.h b/engine/common/assistant_function_tool.h new file mode 100644 index 000000000..7998cb8ff --- /dev/null +++ b/engine/common/assistant_function_tool.h @@ -0,0 +1,130 @@ +#pragma once + +#include +#include "common/assistant_tool.h" +#include "common/json_serializable.h" + +namespace OpenAi { +struct AssistantFunction : public JsonSerializable { + AssistantFunction(const std::string& description, const std::string& name, + const Json::Value& parameters, + const std::optional& strict) + : description{std::move(description)}, + name{std::move(name)}, + parameters{std::move(parameters)}, + strict{strict} {} + + AssistantFunction(const AssistantFunction&) = delete; + + AssistantFunction& operator=(const AssistantFunction&) = delete; + + AssistantFunction(AssistantFunction&&) = default; + + AssistantFunction& operator=(AssistantFunction&&) = default; + + ~AssistantFunction() = default; + + /** + * A description of what the function does, used by the model to choose + * when and how to call the function. + */ + std::string description; + + /** + * The name of the function to be called. 
Must be a-z, A-Z, 0-9, or contain + * underscores and dashes, with a maximum length of 64. + */ + std::string name; + + /** + * The parameters the functions accepts, described as a JSON Schema object. + * See the guide for examples, and the JSON Schema reference for documentation + * about the format. + * + * Omitting parameters defines a function with an empty parameter list. + */ + Json::Value parameters; + + /** + * Whether to enable strict schema adherence when generating the function call. + * If set to true, the model will follow the exact schema defined in the parameters + * field. Only a subset of JSON Schema is supported when strict is true. + * + * Learn more about Structured Outputs in the function calling guide. + */ + std::optional strict; + + static cpp::result FromJson( + const Json::Value& json) { + if (json.empty()) { + return cpp::fail("Function json can't be empty"); + } + + if (!json.isMember("name") || json.get("name", "").asString().empty()) { + return cpp::fail("Function name can't be empty"); + } + + if (!json.isMember("description")) { + return cpp::fail("Function description is mandatory"); + } + + if (!json.isMember("parameters")) { + return cpp::fail("Function parameters are mandatory"); + } + + std::optional is_strict = std::nullopt; + if (json.isMember("strict")) { + is_strict = json["strict"].asBool(); + } + AssistantFunction function{json["description"].asString(), + json["name"].asString(), json["parameters"], + is_strict}; + function.parameters = json["parameters"]; + return function; + } + + cpp::result ToJson() override { + Json::Value json; + json["description"] = description; + json["name"] = name; + if (strict.has_value()) { + json["strict"] = *strict; + } + json["parameters"] = parameters; + return json; + } +}; + +struct AssistantFunctionTool : public AssistantTool { + AssistantFunctionTool(AssistantFunction& function) + : AssistantTool("function"), function{std::move(function)} {} + + AssistantFunctionTool(const 
AssistantFunctionTool&) = delete; + + AssistantFunctionTool& operator=(const AssistantFunctionTool&) = delete; + + AssistantFunctionTool(AssistantFunctionTool&&) = default; + + AssistantFunctionTool& operator=(AssistantFunctionTool&&) = default; + + ~AssistantFunctionTool() = default; + + AssistantFunction function; + + static cpp::result FromJson( + const Json::Value& json) { + auto function_res = AssistantFunction::FromJson(json["function"]); + if (function_res.has_error()) { + return cpp::fail("Failed to parse function: " + function_res.error()); + } + return AssistantFunctionTool{function_res.value()}; + } + + cpp::result ToJson() override { + Json::Value root; + root["type"] = type; + root["function"] = function.ToJson().value(); + return root; + } +}; +}; // namespace OpenAi diff --git a/engine/common/assistant_tool.h b/engine/common/assistant_tool.h index 622721708..d02392392 100644 --- a/engine/common/assistant_tool.h +++ b/engine/common/assistant_tool.h @@ -1,91 +1,27 @@ #pragma once -#include #include +#include "common/json_serializable.h" namespace OpenAi { -struct AssistantTool { +struct AssistantTool : public JsonSerializable { std::string type; AssistantTool(const std::string& type) : type{type} {} - virtual ~AssistantTool() = default; -}; - -struct AssistantCodeInterpreterTool : public AssistantTool { - AssistantCodeInterpreterTool() : AssistantTool{"code_interpreter"} {} - - ~AssistantCodeInterpreterTool() = default; -}; - -struct AssistantFileSearchTool : public AssistantTool { - AssistantFileSearchTool() : AssistantTool("file_search") {} - - ~AssistantFileSearchTool() = default; + AssistantTool(const AssistantTool&) = delete; - /** - * The ranking options for the file search. If not specified, - * the file search tool will use the auto ranker and a score_threshold of 0. - * - * See the file search tool documentation for more information. - */ - struct RankingOption { - /** - * The ranker to use for the file search. 
If not specified will use the auto ranker. - */ - std::string ranker; + AssistantTool& operator=(const AssistantTool&) = delete; - /** - * The score threshold for the file search. All values must be a - * floating point number between 0 and 1. - */ - float score_threshold; - }; + AssistantTool(AssistantTool&& other) noexcept : type{std::move(other.type)} {} - /** - * Overrides for the file search tool. - */ - struct FileSearch { - /** - * The maximum number of results the file search tool should output. - * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. - * This number should be between 1 and 50 inclusive. - * - * Note that the file search tool may output fewer than max_num_results results. - * See the file search tool documentation for more information. - */ - int max_num_result; - }; -}; - -struct AssistantFunctionTool : public AssistantTool { - AssistantFunctionTool() : AssistantTool("function") {} - - ~AssistantFunctionTool() = default; - - struct Function { - /** - * A description of what the function does, used by the model to choose - * when and how to call the function. - */ - std::string description; + AssistantTool& operator=(AssistantTool&& other) noexcept { + if (this != &other) { + type = std::move(other.type); + } + return *this; + } - /** - * The name of the function to be called. Must be a-z, A-Z, 0-9, or contain - * underscores and dashes, with a maximum length of 64. - */ - std::string name; - - // TODO: namh handle parameters - - /** - * Whether to enable strict schema adherence when generating the function call. - * If set to true, the model will follow the exact schema defined in the parameters - * field. Only a subset of JSON Schema is supported when strict is true. - * - * Learn more about Structured Outputs in the function calling guide. 
- */ - std::optional strict; - }; + virtual ~AssistantTool() = default; }; } // namespace OpenAi diff --git a/engine/common/dto/assistant_create_dto.h b/engine/common/dto/assistant_create_dto.h new file mode 100644 index 000000000..19d79b833 --- /dev/null +++ b/engine/common/dto/assistant_create_dto.h @@ -0,0 +1,211 @@ +#pragma once + +#include +#include +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" +#include "common/assistant_tool.h" +#include "common/dto/base_dto.h" +#include "common/tool_resources.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct CreateAssistantDto : public BaseDto { + CreateAssistantDto() = default; + + ~CreateAssistantDto() = default; + + CreateAssistantDto(const CreateAssistantDto&) = delete; + + CreateAssistantDto& operator=(const CreateAssistantDto&) = delete; + + CreateAssistantDto(CreateAssistantDto&& other) noexcept + : model{std::move(other.model)}, + name{std::move(other.name)}, + description{std::move(other.description)}, + instructions{std::move(other.instructions)}, + tools{std::move(other.tools)}, + tool_resources{std::move(other.tool_resources)}, + metadata{std::move(other.metadata)}, + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + CreateAssistantDto& operator=(CreateAssistantDto&& other) noexcept { + if (this != &other) { + model = std::move(other.model); + name = std::move(other.name); + description = std::move(other.description); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources), + metadata = std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + + std::string model; + + std::optional 
name; + + std::optional description; + + std::optional instructions; + + /** + * A list of tool enabled on the assistant. There can be a maximum of + * 128 tools per assistant. Tools can be of types code_interpreter, + * file_search, or function. + */ + std::vector> tools; + + /** + * A set of resources that are used by the assistant's tools. The resources + * are specific to the type of tool. For example, the code_interpreter tool + * requires a list of file IDs, while the file_search tool requires a list + * of vector store IDs. + */ + std::unique_ptr tool_resources; + + std::optional metadata; + + std::optional temperature; + + std::optional top_p; + + std::optional> response_format; + + cpp::result Validate() const override { + if (model.empty()) { + return cpp::fail("Model is mandatory"); + } + + if (response_format.has_value()) { + const auto& variant_value = response_format.value(); + if (std::holds_alternative(variant_value)) { + if (std::get(variant_value) != "auto") { + return cpp::fail("Invalid response_format"); + } + } + } + + return {}; + } + + static CreateAssistantDto FromJson(Json::Value&& root) { + if (root.empty()) { + throw std::runtime_error("Json passed in FromJson can't be empty"); + } + CreateAssistantDto dto; + dto.model = std::move(root["model"].asString()); + if (root.isMember("name")) { + dto.name = std::move(root["name"].asString()); + } + if (root.isMember("description")) { + dto.description = std::move(root["description"].asString()); + } + if (root.isMember("instructions")) { + dto.instructions = std::move(root["instructions"].asString()); + } + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + dto.metadata = std::move(res.value()); + } + } + if (root.isMember("temperature")) { + dto.temperature = root["temperature"].asFloat(); + } + if 
(root.isMember("top_p")) { + dto.top_p = root["top_p"].asFloat(); + } + if (root.isMember("tools") && root["tools"].isArray()) { + auto tools_array = root["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = OpenAi::AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = OpenAi::AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + result.error()); + } + } else if (tool_type == "function") { + auto result = OpenAi::AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + if (root.isMember("tool_resources") && root["tool_resources"].isObject()) { + const auto& tool_resources_json = root["tool_resources"]; + + // Parse code interpreter resources + if (tool_resources_json.isMember("code_interpreter")) { + auto result = OpenAi::CodeInterpreter::FromJson( + tool_resources_json["code_interpreter"]); + if (result.has_value()) { + dto.tool_resources = std::make_unique( + std::move(result.value())); + } else { + CTL_WRN("Failed to parse code_interpreter resources: " + + result.error()); + } + } + + // Parse file search resources + if (tool_resources_json.isMember("file_search")) { + auto result = + 
OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]); + if (result.has_value()) { + dto.tool_resources = + std::make_unique(std::move(result.value())); + } else { + CTL_WRN("Failed to parse file_search resources: " + result.error()); + } + } + } + if (root.isMember("response_format")) { + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } + } + return dto; + } +}; +} // namespace dto diff --git a/engine/common/dto/assistant_update_dto.h b/engine/common/dto/assistant_update_dto.h new file mode 100644 index 000000000..01e5844d7 --- /dev/null +++ b/engine/common/dto/assistant_update_dto.h @@ -0,0 +1,201 @@ +#pragma once + +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" +#include "common/dto/base_dto.h" +#include "common/tool_resources.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct UpdateAssistantDto : public BaseDto { + UpdateAssistantDto() = default; + + ~UpdateAssistantDto() = default; + + UpdateAssistantDto(const UpdateAssistantDto&) = delete; + + UpdateAssistantDto& operator=(const UpdateAssistantDto&) = delete; + + UpdateAssistantDto(UpdateAssistantDto&& other) noexcept + : model{std::move(other.model)}, + name{std::move(other.name)}, + description{std::move(other.description)}, + instructions{std::move(other.instructions)}, + tools{std::move(other.tools)}, + tool_resources{std::move(other.tool_resources)}, + metadata{std::move(other.metadata)}, + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + UpdateAssistantDto& operator=(UpdateAssistantDto&& 
other) noexcept { + if (this != &other) { + model = std::move(other.model); + name = std::move(other.name); + description = std::move(other.description); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources), + metadata = std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + std::optional model; + + std::optional name; + + std::optional description; + + std::optional instructions; + + /** + * A list of tool enabled on the assistant. There can be a maximum of + * 128 tools per assistant. Tools can be of types code_interpreter, + * file_search, or function. + */ + std::vector> tools; + + /** + * A set of resources that are used by the assistant's tools. The resources + * are specific to the type of tool. For example, the code_interpreter tool + * requires a list of file IDs, while the file_search tool requires a list + * of vector store IDs. 
+ */ + std::unique_ptr tool_resources; + + std::optional metadata; + + std::optional temperature; + + std::optional top_p; + + std::optional> response_format; + + cpp::result Validate() const override { + if (!model.has_value() && !name.has_value() && !description.has_value() && + !instructions.has_value() && !metadata.has_value() && + !temperature.has_value() && !top_p.has_value() && + !response_format.has_value()) { + return cpp::fail("At least one field must be provided"); + } + + return {}; + } + + static UpdateAssistantDto FromJson(Json::Value&& root) { + if (root.empty()) { + throw std::runtime_error("Json passed in FromJson can't be empty"); + } + UpdateAssistantDto dto; + dto.model = std::move(root["model"].asString()); + if (root.isMember("name")) { + dto.name = std::move(root["name"].asString()); + } + if (root.isMember("description")) { + dto.description = std::move(root["description"].asString()); + } + if (root.isMember("instruction")) { + dto.instructions = std::move(root["instruction"].asString()); + } + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + dto.metadata = std::move(res.value()); + } + } + if (root.isMember("temperature")) { + dto.temperature = root["temperature"].asFloat(); + } + if (root.isMember("top_p")) { + dto.top_p = root["top_p"].asFloat(); + } + if (root.isMember("tools") && root["tools"].isArray()) { + auto tools_array = root["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = OpenAi::AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); 
+ } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = OpenAi::AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + result.error()); + } + } else if (tool_type == "function") { + auto result = OpenAi::AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + if (root.isMember("tool_resources") && root["tool_resources"].isObject()) { + const auto& tool_resources_json = root["tool_resources"]; + + // Parse code interpreter resources + if (tool_resources_json.isMember("code_interpreter")) { + auto result = OpenAi::CodeInterpreter::FromJson( + tool_resources_json["code_interpreter"]); + if (result.has_value()) { + dto.tool_resources = std::make_unique( + std::move(result.value())); + } else { + CTL_WRN("Failed to parse code_interpreter resources: " + + result.error()); + } + } + + // Parse file search resources + if (tool_resources_json.isMember("file_search")) { + auto result = + OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]); + if (result.has_value()) { + dto.tool_resources = + std::make_unique(std::move(result.value())); + } else { + CTL_WRN("Failed to parse file_search resources: " + result.error()); + } + } + } + if (root.isMember("response_format")) { + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } + } 
+ return dto; + }; +}; +} // namespace dto diff --git a/engine/common/dto/base_dto.h b/engine/common/dto/base_dto.h new file mode 100644 index 000000000..ed7460aa3 --- /dev/null +++ b/engine/common/dto/base_dto.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include "utils/result.hpp" + +namespace dto { +template +struct BaseDto { + virtual ~BaseDto() = default; + + /** + * Validate itself. + */ + virtual cpp::result Validate() const = 0; +}; +} // namespace dto diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h index 767ec9bea..6a0fb02e9 100644 --- a/engine/common/message_attachment.h +++ b/engine/common/message_attachment.h @@ -4,22 +4,27 @@ #include "common/json_serializable.h" namespace OpenAi { - // The tools to add this file to. struct Tool { std::string type; Tool(const std::string& type) : type{type} {} + + virtual ~Tool() = default; }; // The type of tool being defined: code_interpreter -struct CodeInterpreter : Tool { - CodeInterpreter() : Tool{"code_interpreter"} {} +struct MessageCodeInterpreter : Tool { + MessageCodeInterpreter() : Tool{"code_interpreter"} {} + + ~MessageCodeInterpreter() = default; }; // The type of tool being defined: file_search -struct FileSearch : Tool { - FileSearch() : Tool{"file_search"} {} +struct MessageFileSearch : Tool { + MessageFileSearch() : Tool{"file_search"} {} + + ~MessageFileSearch() = default; }; // A list of files attached to the message, and the tools they were added to. 
diff --git a/engine/common/repository/assistant_repository.h b/engine/common/repository/assistant_repository.h new file mode 100644 index 000000000..d0ff1908d --- /dev/null +++ b/engine/common/repository/assistant_repository.h @@ -0,0 +1,25 @@ +#pragma once + +#include "common/assistant.h" +#include "utils/result.hpp" + +class AssistantRepository { + public: + virtual cpp::result, std::string> + ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, const std::string& before) const = 0; + + virtual cpp::result CreateAssistant( + OpenAi::Assistant& assistant) = 0; + + virtual cpp::result RetrieveAssistant( + const std::string assistant_id) const = 0; + + virtual cpp::result ModifyAssistant( + OpenAi::Assistant& assistant) = 0; + + virtual cpp::result DeleteAssistant( + const std::string& assitant_id) = 0; + + virtual ~AssistantRepository() = default; +}; diff --git a/engine/common/thread.h b/engine/common/thread.h index 2bd5d866b..dc57ba32d 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -4,7 +4,7 @@ #include #include #include "common/assistant.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" #include "json_serializable.h" #include "utils/logging_utils.h" @@ -36,7 +36,7 @@ struct Thread : JsonSerializable { * of tool. For example, the code_interpreter tool requires a list of * file IDs, while the file_search tool requires a list of vector store IDs. */ - std::unique_ptr tool_resources; + std::unique_ptr tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. 
@@ -65,7 +65,7 @@ struct Thread : JsonSerializable { const auto& tool_json = json["tool_resources"]; if (tool_json.isMember("code_interpreter")) { - auto code_interpreter = std::make_unique(); + auto code_interpreter = std::make_unique(); const auto& file_ids = tool_json["code_interpreter"]["file_ids"]; if (file_ids.isArray()) { for (const auto& file_id : file_ids) { @@ -74,7 +74,7 @@ struct Thread : JsonSerializable { } thread.tool_resources = std::move(code_interpreter); } else if (tool_json.isMember("file_search")) { - auto file_search = std::make_unique(); + auto file_search = std::make_unique(); const auto& store_ids = tool_json["file_search"]["vector_store_ids"]; if (store_ids.isArray()) { for (const auto& store_id : store_ids) { @@ -148,10 +148,10 @@ struct Thread : JsonSerializable { Json::Value tool_json; if (auto code_interpreter = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["code_interpreter"] = tool_result.value(); } else if (auto file_search = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["file_search"] = tool_result.value(); } json["tool_resources"] = tool_json; diff --git a/engine/common/thread_tool_resources.h b/engine/common/thread_tool_resources.h deleted file mode 100644 index 3c22a4480..000000000 --- a/engine/common/thread_tool_resources.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include -#include -#include "common/json_serializable.h" - -namespace OpenAi { - -struct ThreadToolResources : JsonSerializable { - ~ThreadToolResources() = default; - - virtual cpp::result ToJson() override = 0; -}; - -struct ThreadCodeInterpreter : ThreadToolResources { - std::vector file_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value file_ids_json{Json::arrayValue}; - for (auto& file_id : file_ids) { - file_ids_json.append(file_id); - } - json["file_ids"] = file_ids_json; - return json; - } catch (const std::exception& e) { - return 
cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; - -struct ThreadFileSearch : ThreadToolResources { - std::vector vector_store_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value vector_store_ids_json{Json::arrayValue}; - for (auto& vector_store_id : vector_store_ids) { - vector_store_ids_json.append(vector_store_id); - } - json["vector_store_ids"] = vector_store_ids_json; - return json; - } catch (const std::exception& e) { - return cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; -} // namespace OpenAi diff --git a/engine/common/tool_resources.h b/engine/common/tool_resources.h new file mode 100644 index 000000000..5aadb3f8b --- /dev/null +++ b/engine/common/tool_resources.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include +#include "common/json_serializable.h" + +namespace OpenAi { + +struct ToolResources : JsonSerializable { + ToolResources() = default; + + ToolResources(const ToolResources&) = delete; + + ToolResources& operator=(const ToolResources&) = delete; + + ToolResources(ToolResources&&) noexcept = default; + + ToolResources& operator=(ToolResources&&) noexcept = default; + + virtual ~ToolResources() = default; + + virtual cpp::result ToJson() override = 0; +}; + +struct CodeInterpreter : ToolResources { + CodeInterpreter() = default; + + ~CodeInterpreter() override = default; + + CodeInterpreter(const CodeInterpreter&) = delete; + + CodeInterpreter& operator=(const CodeInterpreter&) = delete; + + CodeInterpreter(CodeInterpreter&& other) noexcept + : ToolResources(std::move(other)), file_ids(std::move(other.file_ids)) {} + + CodeInterpreter& operator=(CodeInterpreter&& other) noexcept { + if (this != &other) { + ToolResources::operator=(std::move(other)); + file_ids = std::move(other.file_ids); + } + return *this; + } + + std::vector file_ids; + + static cpp::result FromJson( + const Json::Value& json) { + CodeInterpreter code_interpreter; + if (json.isMember("file_ids")) { + for 
(const auto& file_id : json["file_ids"]) { + code_interpreter.file_ids.push_back(file_id.asString()); + } + } + return code_interpreter; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value file_ids_json{Json::arrayValue}; + for (auto& file_id : file_ids) { + file_ids_json.append(file_id); + } + json["file_ids"] = file_ids_json; + return json; + } +}; + +struct FileSearch : ToolResources { + FileSearch() = default; + + ~FileSearch() override = default; + + FileSearch(const FileSearch&) = delete; + + FileSearch& operator=(const FileSearch&) = delete; + + FileSearch(FileSearch&& other) noexcept + : ToolResources(std::move(other)), + vector_store_ids{std::move(other.vector_store_ids)} {} + + FileSearch& operator=(FileSearch&& other) noexcept { + if (this != &other) { + ToolResources::operator=(std::move(other)); + + vector_store_ids = std::move(other.vector_store_ids); + } + return *this; + } + + std::vector vector_store_ids; + + static cpp::result FromJson( + const Json::Value& json) { + FileSearch file_search; + if (json.isMember("vector_store_ids")) { + for (const auto& vector_store_id : json["vector_store_ids"]) { + file_search.vector_store_ids.push_back(vector_store_id.asString()); + } + } + return file_search; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value vector_store_ids_json{Json::arrayValue}; + for (auto& vector_store_id : vector_store_ids) { + vector_store_ids_json.append(vector_store_id); + } + json["vector_store_ids"] = vector_store_ids_json; + return json; + } +}; +} // namespace OpenAi diff --git a/engine/controllers/assistants.cc b/engine/controllers/assistants.cc index 405d7ed3c..530e180a5 100644 --- a/engine/controllers/assistants.cc +++ b/engine/controllers/assistants.cc @@ -1,4 +1,6 @@ #include "assistants.h" +#include "common/api-dto/delete_success_response.h" +#include "common/dto/assistant_create_dto.h" #include "utils/cortex_utils.h" #include "utils/logging_utils.h" @@ -6,7 +8,12 @@ void 
Assistants::RetrieveAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) const { - CTL_INF("RetrieveAssistant: " + assistant_id); + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return RetrieveAssistantV2(req, std::move(callback), assistant_id); + } + auto res = assistant_service_->RetrieveAssistant(assistant_id); if (res.has_error()) { Json::Value ret; @@ -33,6 +40,78 @@ void Assistants::RetrieveAssistant( } } +void Assistants::RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const { + auto res = assistant_service_->RetrieveAssistantV2(assistant_id); + + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto to_json_res = res->ToJson(); + if (to_json_res.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_res.error()); + Json::Value ret; + ret["message"] = to_json_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + // TODO: namh need to use the text response because it contains model config + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Assistants::CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::CreateAssistantDto::FromJson(std::move(*json_body)); + 
CTL_INF("CreateAssistantV2: " << dto.model); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->CreateAssistantV2(dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto to_json_res = res->ToJson(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(to_json_res.value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::CreateAssistant( const HttpRequestPtr& req, std::function&& callback, @@ -88,10 +167,55 @@ void Assistants::CreateAssistant( callback(resp); } +void Assistants::ModifyAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::UpdateAssistantDto::FromJson(std::move(*json_body)); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->ModifyAssistantV2(assistant_id, dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = 
cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::ModifyAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) { + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return ModifyAssistantV2(req, std::move(callback), assistant_id); + } auto json_body = req->getJsonObject(); if (json_body == nullptr) { Json::Value ret; @@ -142,3 +266,62 @@ void Assistants::ModifyAssistant( resp->setStatusCode(k200OK); callback(resp); } + +void Assistants::ListAssistants( + const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, std::optional order, + std::optional after, std::optional before) const { + + auto res = assistant_service_->ListAssistants( + std::stoi(limit.value_or("20")), order.value_or("desc"), + after.value_or(""), before.value_or("")); + if (res.has_error()) { + Json::Value root; + root["message"] = res.error(); + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + + Json::Value assistant_list(Json::arrayValue); + for (auto& msg : res.value()) { + if (auto it = msg.ToJson(); it.has_value()) { + assistant_list.append(it.value()); + } else { + CTL_WRN("Failed to convert message to json: " + it.error()); + } + } + + Json::Value root; + root["object"] = "list"; + root["data"] = assistant_list; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k200OK); + callback(response); +} + +void Assistants::DeleteAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto res = assistant_service_->DeleteAssistantV2(assistant_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = 
cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + api_response::DeleteSuccessResponse response; + response.id = assistant_id; + response.object = "assistant.deleted"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/assistants.h b/engine/controllers/assistants.h index 94ddd14b1..30111bb01 100644 --- a/engine/controllers/assistants.h +++ b/engine/controllers/assistants.h @@ -7,33 +7,72 @@ using namespace drogon; class Assistants : public drogon::HttpController { + constexpr static auto kOpenAiAssistantKeyV2 = "openai-beta"; + constexpr static auto kOpenAiAssistantValueV2 = "assistants=v2"; + public: METHOD_LIST_BEGIN + ADD_METHOD_TO( + Assistants::ListAssistants, + "/v1/" + "assistants?limit={limit}&order={order}&after={after}&before={before}", + Get); + + ADD_METHOD_TO(Assistants::DeleteAssistant, "/v1/assistants/{assistant_id}", + Options, Delete); + ADD_METHOD_TO(Assistants::RetrieveAssistant, "/v1/assistants/{assistant_id}", Get); ADD_METHOD_TO(Assistants::CreateAssistant, "/v1/assistants/{assistant_id}", Options, Post); + ADD_METHOD_TO(Assistants::CreateAssistantV2, "/v1/assistants", Options, Post); + ADD_METHOD_TO(Assistants::ModifyAssistant, "/v1/assistants/{assistant_id}", Options, Patch); + METHOD_LIST_END explicit Assistants(std::shared_ptr assistant_srv) : assistant_service_{assistant_srv} {}; + void ListAssistants(const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, + std::optional order, + std::optional after, + std::optional before) const; + void RetrieveAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) const; + void RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const; + + void 
DeleteAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + void CreateAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback); + void ModifyAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void ModifyAssistantV2(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + private: std::shared_ptr assistant_service_; }; diff --git a/engine/main.cc b/engine/main.cc index ddf1eefd8..938392bf0 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -15,6 +15,7 @@ #include "controllers/threads.h" #include "database/database.h" #include "migrations/migration_manager.h" +#include "repositories/assistant_fs_repository.h" #include "repositories/file_fs_repository.h" #include "repositories/message_fs_repository.h" #include "repositories/thread_fs_repository.h" @@ -142,9 +143,12 @@ void RunServer(std::optional host, std::optional port, auto file_repo = std::make_shared(data_folder_path); auto msg_repo = std::make_shared(data_folder_path); auto thread_repo = std::make_shared(data_folder_path); + auto assistant_repo = + std::make_shared(data_folder_path); auto file_srv = std::make_shared(file_repo); - auto assistant_srv = std::make_shared(thread_repo); + auto assistant_srv = + std::make_shared(thread_repo, assistant_repo); auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); diff --git a/engine/repositories/assistant_fs_repository.cc b/engine/repositories/assistant_fs_repository.cc new file mode 100644 index 000000000..87b4174fd --- /dev/null +++ b/engine/repositories/assistant_fs_repository.cc @@ -0,0 +1,214 @@ +#include "assistant_fs_repository.h" +#include +#include +#include +#include +#include "utils/result.hpp" + +cpp::result, std::string> 
+AssistantFsRepository::ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + std::vector assistants; + try { + auto assistant_container_path = + data_folder_path_ / kAssistantContainerFolderName; + std::vector all_assistants; + + for (const auto& entry : + std::filesystem::directory_iterator(assistant_container_path)) { + if (!entry.is_directory()) { + continue; + } + + auto assistant_file = entry.path() / kAssistantFileName; + if (!std::filesystem::exists(assistant_file)) { + continue; + } + + auto current_assistant_id = entry.path().filename().string(); + + if (!after.empty() && current_assistant_id <= after) { + continue; + } + + if (!before.empty() && current_assistant_id >= before) { + continue; + } + + std::shared_lock assistant_lock(GrabAssistantMutex(current_assistant_id)); + auto assistant_res = LoadAssistant(current_assistant_id); + if (assistant_res.has_value()) { + all_assistants.push_back(std::move(assistant_res.value())); + } + assistant_lock.unlock(); + } + + // sorting + if (order == "desc") { + std::sort(all_assistants.begin(), all_assistants.end(), + [](const OpenAi::Assistant& assistant1, + const OpenAi::Assistant& assistant2) { + return assistant1.created_at > assistant2.created_at; + }); + } else { + std::sort(all_assistants.begin(), all_assistants.end(), + [](const OpenAi::Assistant& assistant1, + const OpenAi::Assistant& assistant2) { + return assistant1.created_at < assistant2.created_at; + }); + } + + size_t assistant_count = + std::min(static_cast(limit), all_assistants.size()); + for (size_t i = 0; i < assistant_count; i++) { + assistants.push_back(std::move(all_assistants[i])); + } + + return assistants; + } catch (const std::exception& e) { + return cpp::fail("Failed to list assistants: " + std::string(e.what())); + } +} + +cpp::result +AssistantFsRepository::RetrieveAssistant(const std::string assistant_id) const { + std::shared_lock 
lock(GrabAssistantMutex(assistant_id)); + return LoadAssistant(assistant_id); +} + +cpp::result AssistantFsRepository::ModifyAssistant( + OpenAi::Assistant& assistant) { + { + std::unique_lock lock(GrabAssistantMutex(assistant.id)); + auto path = GetAssistantPath(assistant.id); + + if (!std::filesystem::exists(path)) { + lock.unlock(); + return cpp::fail("Assistant doesn't exist: " + assistant.id); + } + } + + return SaveAssistant(assistant); +} + +cpp::result AssistantFsRepository::DeleteAssistant( + const std::string& assitant_id) { + { + std::unique_lock assistant_lock(GrabAssistantMutex(assitant_id)); + auto path = GetAssistantPath(assitant_id); + if (!std::filesystem::exists(path)) { + return cpp::fail("Assistant doesn't exist: " + assitant_id); + } + try { + std::filesystem::remove_all(path); + } catch (const std::exception& e) { + return cpp::fail(""); + } + } + + std::unique_lock map_lock(map_mutex_); + assistant_mutexes_.erase(assitant_id); + return {}; +} + +cpp::result +AssistantFsRepository::CreateAssistant(OpenAi::Assistant& assistant) { + CTL_INF("CreateAssistant: " + assistant.id); + { + std::unique_lock lock(GrabAssistantMutex(assistant.id)); + auto path = GetAssistantPath(assistant.id); + + if (std::filesystem::exists(path)) { + return cpp::fail("Assistant already exists: " + assistant.id); + } + + std::filesystem::create_directories(path); + auto assistant_file_path = path / kAssistantFileName; + std::ofstream assistant_file(assistant_file_path); + assistant_file.close(); + + CTL_INF("CreateAssistant created new file: " + assistant.id); + auto save_result = SaveAssistant(assistant); + if (save_result.has_error()) { + lock.unlock(); + return cpp::fail("Failed to save assistant: " + save_result.error()); + } + } + return RetrieveAssistant(assistant.id); +} + +cpp::result AssistantFsRepository::SaveAssistant( + OpenAi::Assistant& assistant) { + auto path = GetAssistantPath(assistant.id) / kAssistantFileName; + if (!std::filesystem::exists(path)) { + 
std::filesystem::create_directories(path); + } + + std::ofstream file(path); + if (!file) { + return cpp::fail("Failed to open file: " + path.string()); + } + try { + file << assistant.ToJson()->toStyledString(); + file.flush(); + file.close(); + return {}; + } catch (const std::exception& e) { + file.close(); + return cpp::fail("Failed to save assistant: " + std::string(e.what())); + } +} + +std::filesystem::path AssistantFsRepository::GetAssistantPath( + const std::string& assistant_id) const { + auto container_folder_path = + data_folder_path_ / kAssistantContainerFolderName; + if (!std::filesystem::exists(container_folder_path)) { + std::filesystem::create_directories(container_folder_path); + } + + return data_folder_path_ / kAssistantContainerFolderName / assistant_id; +} + +std::shared_mutex& AssistantFsRepository::GrabAssistantMutex( + const std::string& assistant_id) const { + std::shared_lock map_lock(map_mutex_); + auto it = assistant_mutexes_.find(assistant_id); + if (it != assistant_mutexes_.end()) { + return *it->second; + } + + map_lock.unlock(); + std::unique_lock map_write_lock(map_mutex_); + return *assistant_mutexes_ + .try_emplace(assistant_id, std::make_unique()) + .first->second; +} + +cpp::result +AssistantFsRepository::LoadAssistant(const std::string& assistant_id) const { + auto path = GetAssistantPath(assistant_id) / kAssistantFileName; + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + try { + std::ifstream file(path); + if (!file.is_open()) { + return cpp::fail("Failed to open file: " + path.string()); + } + + Json::Value root; + Json::CharReaderBuilder builder; + JSONCPP_STRING errs; + + if (!parseFromStream(builder, file, &root, &errs)) { + return cpp::fail("Failed to parse JSON: " + errs); + } + + return OpenAi::Assistant::FromJson(std::move(root)); + } catch (const std::exception& e) { + return cpp::fail("Failed to load assistant: " + std::string(e.what())); + } +} diff --git 
a/engine/repositories/assistant_fs_repository.h b/engine/repositories/assistant_fs_repository.h new file mode 100644 index 000000000..f310bd54e --- /dev/null +++ b/engine/repositories/assistant_fs_repository.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include + +#include "common/repository/assistant_repository.h" + +class AssistantFsRepository : public AssistantRepository { + public: + constexpr static auto kAssistantContainerFolderName = "assistants"; + constexpr static auto kAssistantFileName = "assistant.json"; + + cpp::result, std::string> ListAssistants( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const override; + + cpp::result CreateAssistant( + OpenAi::Assistant& assistant) override; + + cpp::result RetrieveAssistant( + const std::string assistant_id) const override; + + cpp::result ModifyAssistant( + OpenAi::Assistant& assistant) override; + + cpp::result DeleteAssistant( + const std::string& assitant_id) override; + + explicit AssistantFsRepository(const std::filesystem::path& data_folder_path) + : data_folder_path_{data_folder_path} { + CTL_INF("Constructing AssistantFsRepository.."); + auto path = data_folder_path_ / kAssistantContainerFolderName; + + if (!std::filesystem::exists(path)) { + std::filesystem::create_directories(path); + } + } + + ~AssistantFsRepository() = default; + + private: + std::filesystem::path GetAssistantPath(const std::string& assistant_id) const; + + std::shared_mutex& GrabAssistantMutex(const std::string& assistant_id) const; + + cpp::result SaveAssistant(OpenAi::Assistant& assistant); + + cpp::result LoadAssistant( + const std::string& assistant_id) const; + + /** + * The path to the data folder. 
+ */ + std::filesystem::path data_folder_path_; + + mutable std::shared_mutex map_mutex_; + mutable std::unordered_map> + assistant_mutexes_; +}; diff --git a/engine/repositories/file_fs_repository.h b/engine/repositories/file_fs_repository.h index 974e81fa4..77af60dfc 100644 --- a/engine/repositories/file_fs_repository.h +++ b/engine/repositories/file_fs_repository.h @@ -28,7 +28,7 @@ class FileFsRepository : public FileRepository { cpp::result DeleteFileLocal( const std::string& file_id) override; - explicit FileFsRepository(std::filesystem::path data_folder_path) + explicit FileFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing FileFsRepository.."); auto file_container_path = data_folder_path_ / kFileContainerFolderName; diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h index 2146778bf..0ca6e89b3 100644 --- a/engine/repositories/message_fs_repository.h +++ b/engine/repositories/message_fs_repository.h @@ -32,7 +32,7 @@ class MessageFsRepository : public MessageRepository { const std::string& thread_id, std::optional> messages) override; - explicit MessageFsRepository(std::filesystem::path data_folder_path) + explicit MessageFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing MessageFsRepository.."); auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; diff --git a/engine/services/assistant_service.cc b/engine/services/assistant_service.cc index e769bf23f..08a5a743f 100644 --- a/engine/services/assistant_service.cc +++ b/engine/services/assistant_service.cc @@ -1,5 +1,7 @@ #include "assistant_service.h" +#include #include "utils/logging_utils.h" +#include "utils/ulid_generator.h" cpp::result AssistantService::CreateAssistant(const std::string& thread_id, @@ -26,3 +28,181 @@ AssistantService::ModifyAssistant(const std::string& thread_id, 
CTL_INF("RetrieveAssistant: " + thread_id); return thread_repository_->ModifyAssistant(thread_id, assistant); } + +cpp::result, std::string> +AssistantService::ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + CTL_INF("List assistants invoked"); + return assistant_repository_->ListAssistants(limit, order, after, before); +} + +cpp::result AssistantService::CreateAssistantV2( + const dto::CreateAssistantDto& create_dto) { + + OpenAi::Assistant assistant; + assistant.id = "asst_" + ulid::GenerateUlid(); + assistant.model = create_dto.model; + if (create_dto.name) { + assistant.name = *create_dto.name; + } + if (create_dto.description) { + assistant.description = *create_dto.description; + } + if (create_dto.instructions) { + assistant.instructions = *create_dto.instructions; + } + if (create_dto.metadata) { + assistant.metadata = *create_dto.metadata; + } + if (create_dto.temperature) { + assistant.temperature = *create_dto.temperature; + } + if (create_dto.top_p) { + assistant.top_p = *create_dto.top_p; + } + for (auto& tool_ptr : create_dto.tools) { + // Create a new unique_ptr in assistant.tools that takes ownership + if (auto* function_tool = + dynamic_cast(tool_ptr.get())) { + assistant.tools.push_back(std::make_unique( + std::move(*function_tool))); + } else if (auto* code_tool = + dynamic_cast( + tool_ptr.get())) { + assistant.tools.push_back( + std::make_unique( + std::move(*code_tool))); + } else if (auto* search_tool = + dynamic_cast( + tool_ptr.get())) { + assistant.tools.push_back( + std::make_unique( + std::move(*search_tool))); + } + } + if (create_dto.tool_resources) { + if (auto* code_interpreter = dynamic_cast( + create_dto.tool_resources.get())) { + assistant.tool_resources = std::make_unique( + std::move(*code_interpreter)); + } else if (auto* file_search = dynamic_cast( + create_dto.tool_resources.get())) { + assistant.tool_resources = + 
std::make_unique(std::move(*file_search)); + } + } + if (create_dto.response_format) { + assistant.response_format = *create_dto.response_format; + } + auto seconds_since_epoch = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + assistant.created_at = seconds_since_epoch; + return assistant_repository_->CreateAssistant(assistant); +} +cpp::result +AssistantService::RetrieveAssistantV2(const std::string& assistant_id) const { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + return assistant_repository_->RetrieveAssistant(assistant_id); +} + +cpp::result AssistantService::ModifyAssistantV2( + const std::string& assistant_id, + const dto::UpdateAssistantDto& update_dto) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + if (!update_dto.Validate()) { + return cpp::fail("Invalid update assistant dto"); + } + + // First retrieve the existing assistant + auto existing_assistant = + assistant_repository_->RetrieveAssistant(assistant_id); + if (existing_assistant.has_error()) { + return cpp::fail(existing_assistant.error()); + } + + OpenAi::Assistant updated_assistant; + updated_assistant.id = assistant_id; + + // Update fields if they are present in the DTO + if (update_dto.model) { + updated_assistant.model = *update_dto.model; + } + if (update_dto.name) { + updated_assistant.name = *update_dto.name; + } + if (update_dto.description) { + updated_assistant.description = *update_dto.description; + } + if (update_dto.instructions) { + updated_assistant.instructions = *update_dto.instructions; + } + if (update_dto.metadata) { + updated_assistant.metadata = *update_dto.metadata; + } + if (update_dto.temperature) { + updated_assistant.temperature = *update_dto.temperature; + } + if (update_dto.top_p) { + updated_assistant.top_p = *update_dto.top_p; + } + for (auto& tool_ptr : update_dto.tools) { + if (auto* function_tool = + 
dynamic_cast(tool_ptr.get())) { + updated_assistant.tools.push_back( + std::make_unique( + std::move(*function_tool))); + } else if (auto* code_tool = + dynamic_cast( + tool_ptr.get())) { + updated_assistant.tools.push_back( + std::make_unique( + std::move(*code_tool))); + } else if (auto* search_tool = + dynamic_cast( + tool_ptr.get())) { + updated_assistant.tools.push_back( + std::make_unique( + std::move(*search_tool))); + } + } + if (update_dto.tool_resources) { + if (auto* code_interpreter = dynamic_cast( + update_dto.tool_resources.get())) { + updated_assistant.tool_resources = + std::make_unique( + std::move(*code_interpreter)); + } else if (auto* file_search = dynamic_cast( + update_dto.tool_resources.get())) { + updated_assistant.tool_resources = + std::make_unique(std::move(*file_search)); + } + } + if (update_dto.response_format) { + updated_assistant.response_format = *update_dto.response_format; + } + + auto res = assistant_repository_->ModifyAssistant(updated_assistant); + if (res.has_error()) { + return cpp::fail(res.error()); + } + + return updated_assistant; +} + +cpp::result AssistantService::DeleteAssistantV2( + const std::string& assistant_id) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + return assistant_repository_->DeleteAssistant(assistant_id); +} diff --git a/engine/services/assistant_service.h b/engine/services/assistant_service.h index e7f7414d1..ad31104ff 100644 --- a/engine/services/assistant_service.h +++ b/engine/services/assistant_service.h @@ -1,15 +1,14 @@ #pragma once #include "common/assistant.h" +#include "common/dto/assistant_create_dto.h" +#include "common/dto/assistant_update_dto.h" +#include "common/repository/assistant_repository.h" #include "repositories/thread_fs_repository.h" #include "utils/result.hpp" class AssistantService { public: - explicit AssistantService( - std::shared_ptr thread_repository) - : thread_repository_{thread_repository} {} - cpp::result CreateAssistant( 
const std::string& thread_id, const OpenAi::JanAssistant& assistant); @@ -19,6 +18,31 @@ class AssistantService { cpp::result ModifyAssistant( const std::string& thread_id, const OpenAi::JanAssistant& assistant); + // V2 + cpp::result CreateAssistantV2( + const dto::CreateAssistantDto& create_dto); + + cpp::result, std::string> ListAssistants( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const; + + cpp::result RetrieveAssistantV2( + const std::string& assistant_id) const; + + cpp::result ModifyAssistantV2( + const std::string& assistant_id, + const dto::UpdateAssistantDto& update_dto); + + cpp::result DeleteAssistantV2( + const std::string& assistant_id); + + explicit AssistantService( + std::shared_ptr thread_repository, + std::shared_ptr assistant_repository) + : thread_repository_{thread_repository}, + assistant_repository_{assistant_repository} {} + private: std::shared_ptr thread_repository_; + std::shared_ptr assistant_repository_; }; diff --git a/engine/services/thread_service.cc b/engine/services/thread_service.cc index 0ec0ac89d..9c5e7e857 100644 --- a/engine/services/thread_service.cc +++ b/engine/services/thread_service.cc @@ -4,7 +4,7 @@ #include "utils/ulid_generator.h" cpp::result ThreadService::CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "CreateThread"; @@ -46,7 +46,7 @@ cpp::result ThreadService::RetrieveThread( cpp::result ThreadService::ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "ModifyThread " << thread_id; auto retrieve_res = RetrieveThread(thread_id); diff --git a/engine/services/thread_service.h b/engine/services/thread_service.h index 966b0ab01..7011f46f3 100644 --- a/engine/services/thread_service.h +++ b/engine/services/thread_service.h @@ -2,7 +2,6 @@ #include #include 
"common/repository/thread_repository.h" -#include "common/thread_tool_resources.h" #include "common/variant_map.h" #include "utils/result.hpp" @@ -12,7 +11,7 @@ class ThreadService { : thread_repository_{thread_repository} {} cpp::result CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result, std::string> ListThreads( @@ -24,7 +23,7 @@ class ThreadService { cpp::result ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result DeleteThread( diff --git a/engine/test/components/test_assistant.cc b/engine/test/components/test_assistant.cc new file mode 100644 index 000000000..20ba08f34 --- /dev/null +++ b/engine/test/components/test_assistant.cc @@ -0,0 +1,194 @@ +#include +#include "common/assistant.h" + +namespace OpenAi { +namespace { + +class AssistantTest : public ::testing::Test { + protected: + void SetUp() override { + // Set up base assistant with minimal required fields + base_assistant.id = "asst_123"; + base_assistant.object = "assistant"; + base_assistant.created_at = 1702000000; + base_assistant.model = "gpt-4"; + } + + Assistant base_assistant; +}; + +TEST_F(AssistantTest, MinimalAssistantToJson) { + auto result = base_assistant.ToJson(); + ASSERT_TRUE(result.has_value()); + + Json::Value json = result.value(); + EXPECT_EQ(json["id"].asString(), "asst_123"); + EXPECT_EQ(json["object"].asString(), "assistant"); + EXPECT_EQ(json["created_at"].asUInt64(), 1702000000); + EXPECT_EQ(json["model"].asString(), "gpt-4"); +} + +TEST_F(AssistantTest, FullAssistantToJson) { + base_assistant.name = "Test Assistant"; + base_assistant.description = "Test Description"; + base_assistant.instructions = "Test Instructions"; + base_assistant.temperature = 0.7f; + base_assistant.top_p = 0.9f; + + // Add a code interpreter tool + auto code_tool = std::make_unique(); + base_assistant.tools.push_back(std::move(code_tool)); + 
+ // Add metadata + base_assistant.metadata["key1"] = std::string("value1"); + base_assistant.metadata["key2"] = true; + base_assistant.metadata["key3"] = static_cast(42ULL); + + auto result = base_assistant.ToJson(); + ASSERT_TRUE(result.has_value()); + + Json::Value json = result.value(); + EXPECT_EQ(json["name"].asString(), "Test Assistant"); + EXPECT_EQ(json["description"].asString(), "Test Description"); + EXPECT_EQ(json["instructions"].asString(), "Test Instructions"); + EXPECT_FLOAT_EQ(json["temperature"].asFloat(), 0.7f); + EXPECT_FLOAT_EQ(json["top_p"].asFloat(), 0.9f); + + EXPECT_TRUE(json["tools"].isArray()); + EXPECT_EQ(json["tools"].size(), 1); + EXPECT_EQ(json["tools"][0]["type"].asString(), "code_interpreter"); + + EXPECT_TRUE(json["metadata"].isObject()); + EXPECT_EQ(json["metadata"]["key1"].asString(), "value1"); + EXPECT_EQ(json["metadata"]["key2"].asBool(), true); + EXPECT_EQ(json["metadata"]["key3"].asUInt64(), 42ULL); +} + +TEST_F(AssistantTest, FromJsonMinimal) { + Json::Value input; + input["id"] = "asst_123"; + input["object"] = "assistant"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + ASSERT_TRUE(result.has_value()); + + const auto& assistant = result.value(); + EXPECT_EQ(assistant.id, "asst_123"); + EXPECT_EQ(assistant.object, "assistant"); + EXPECT_EQ(assistant.created_at, 1702000000); + EXPECT_EQ(assistant.model, "gpt-4"); +} + +TEST_F(AssistantTest, FromJsonComplete) { + Json::Value input; + input["id"] = "asst_123"; + input["object"] = "assistant"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + input["name"] = "Test Assistant"; + input["description"] = "Test Description"; + input["instructions"] = "Test Instructions"; + input["temperature"] = 0.7; + input["top_p"] = 0.9; + + // Add tools + Json::Value tools(Json::arrayValue); + Json::Value code_tool; + code_tool["type"] = "code_interpreter"; + tools.append(code_tool); + + Json::Value 
function_tool; + function_tool["type"] = "function"; + function_tool["function"] = Json::Value(Json::objectValue); + function_tool["function"]["name"] = "test_function"; + function_tool["function"]["description"] = "Test function"; + function_tool["function"]["parameters"] = Json::Value(Json::objectValue); + tools.append(function_tool); + input["tools"] = tools; + + // Add metadata + Json::Value metadata(Json::objectValue); + metadata["key1"] = "value1"; + metadata["key2"] = true; + metadata["key3"] = 42; + input["metadata"] = metadata; + + auto result = Assistant::FromJson(std::move(input)); + ASSERT_TRUE(result.has_value()); + + const auto& assistant = result.value(); + EXPECT_EQ(assistant.name.value(), "Test Assistant"); + EXPECT_EQ(assistant.description.value(), "Test Description"); + EXPECT_EQ(assistant.instructions.value(), "Test Instructions"); + EXPECT_FLOAT_EQ(assistant.temperature.value(), 0.7f); + EXPECT_FLOAT_EQ(assistant.top_p.value(), 0.9f); + + EXPECT_EQ(assistant.tools.size(), 2); + EXPECT_TRUE(dynamic_cast(assistant.tools[0].get()) != nullptr); + EXPECT_TRUE(dynamic_cast(assistant.tools[1].get()) != nullptr); + + EXPECT_EQ(assistant.metadata.size(), 3); + EXPECT_EQ(std::get(assistant.metadata.at("key1")), "value1"); + EXPECT_EQ(std::get(assistant.metadata.at("key2")), true); + EXPECT_EQ(std::get(assistant.metadata.at("key3")), 42ULL); +} + +TEST_F(AssistantTest, FromJsonInvalidInput) { + // Missing required field 'id' + { + Json::Value input; + input["object"] = "assistant"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + EXPECT_FALSE(result.has_value()); + } + + // Invalid object type + { + Json::Value input; + input["id"] = "asst_123"; + input["object"] = "invalid"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + EXPECT_FALSE(result.has_value()); + } + + // Invalid created_at type + { + Json::Value 
input; + input["id"] = "asst_123"; + input["object"] = "assistant"; + input["created_at"] = "invalid"; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + EXPECT_FALSE(result.has_value()); + } +} + +TEST_F(AssistantTest, MoveConstructorAndAssignment) { + base_assistant.name = "Test Assistant"; + base_assistant.tools.push_back(std::make_unique()); + + // Test move constructor + Assistant moved_assistant(std::move(base_assistant)); + EXPECT_EQ(moved_assistant.id, "asst_123"); + EXPECT_EQ(moved_assistant.name.value(), "Test Assistant"); + EXPECT_EQ(moved_assistant.tools.size(), 1); + + // Test move assignment + Assistant another_assistant; + another_assistant = std::move(moved_assistant); + EXPECT_EQ(another_assistant.id, "asst_123"); + EXPECT_EQ(another_assistant.name.value(), "Test Assistant"); + EXPECT_EQ(another_assistant.tools.size(), 1); +} + +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_assistant_tool_code_interpreter.cc b/engine/test/components/test_assistant_tool_code_interpreter.cc new file mode 100644 index 000000000..f32526504 --- /dev/null +++ b/engine/test/components/test_assistant_tool_code_interpreter.cc @@ -0,0 +1,49 @@ +#include +#include +#include "common/assistant_code_interpreter_tool.h" + +namespace OpenAi { +namespace { + +class AssistantCodeInterpreterToolTest : public ::testing::Test {}; + +TEST_F(AssistantCodeInterpreterToolTest, BasicConstruction) { + AssistantCodeInterpreterTool tool; + EXPECT_EQ(tool.type, "code_interpreter"); +} + +TEST_F(AssistantCodeInterpreterToolTest, MoveConstructor) { + AssistantCodeInterpreterTool original; + AssistantCodeInterpreterTool moved(std::move(original)); + EXPECT_EQ(moved.type, "code_interpreter"); +} + +TEST_F(AssistantCodeInterpreterToolTest, MoveAssignment) { + AssistantCodeInterpreterTool original; + AssistantCodeInterpreterTool target; + target = std::move(original); + EXPECT_EQ(target.type, "code_interpreter"); +} + 
+TEST_F(AssistantCodeInterpreterToolTest, FromJson) { + Json::Value json; // Empty JSON is fine for this tool + auto result = AssistantCodeInterpreterTool::FromJson(); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value().type, "code_interpreter"); +} + +TEST_F(AssistantCodeInterpreterToolTest, ToJson) { + AssistantCodeInterpreterTool tool; + auto result = tool.ToJson(); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value()["type"].asString(), "code_interpreter"); + + // Verify no extra fields + Json::Value::Members members = result.value().getMemberNames(); + EXPECT_EQ(members.size(), 1); // Only "type" field should be present + EXPECT_EQ(members[0], "type"); +} +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_assistant_tool_file_search.cc b/engine/test/components/test_assistant_tool_file_search.cc new file mode 100644 index 000000000..25a2ffc05 --- /dev/null +++ b/engine/test/components/test_assistant_tool_file_search.cc @@ -0,0 +1,207 @@ +#include +#include +#include "common/assistant_file_search_tool.h" + +namespace OpenAi { +namespace { + +class AssistantFileSearchToolTest : public ::testing::Test {}; + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionBasicConstruction) { + const float threshold = 0.75f; + const std::string ranker = "test_ranker"; + FileSearchRankingOption option{threshold, ranker}; + + EXPECT_EQ(option.score_threshold, threshold); + EXPECT_EQ(option.ranker, ranker); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionDefaultRanker) { + const float threshold = 0.5f; + FileSearchRankingOption option{threshold}; + + EXPECT_EQ(option.score_threshold, threshold); + EXPECT_EQ(option.ranker, "auto"); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionFromValidJson) { + Json::Value json; + json["score_threshold"] = 0.8f; + json["ranker"] = "custom_ranker"; + + auto result = FileSearchRankingOption::FromJson(json); + ASSERT_TRUE(result.has_value()); + + 
EXPECT_EQ(result.value().score_threshold, 0.8f); + EXPECT_EQ(result.value().ranker, "custom_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionFromInvalidJson) { + Json::Value json; + auto result = FileSearchRankingOption::FromJson(json); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionToJson) { + FileSearchRankingOption option{0.9f, "special_ranker"}; + auto json_result = option.ToJson(); + + ASSERT_TRUE(json_result.has_value()); + Json::Value json = json_result.value(); + + EXPECT_EQ(json["score_threshold"].asFloat(), 0.9f); + EXPECT_EQ(json["ranker"].asString(), "special_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchBasicConstruction) { + FileSearchRankingOption ranking_option{0.7f, "test_ranker"}; + AssistantFileSearch search{10, std::move(ranking_option)}; + + EXPECT_EQ(search.max_num_results, 10); + EXPECT_EQ(search.ranking_options.score_threshold, 0.7f); + EXPECT_EQ(search.ranking_options.ranker, "test_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchFromValidJson) { + Json::Value json; + json["max_num_results"] = 15; + + Json::Value ranking_json; + ranking_json["score_threshold"] = 0.85f; + ranking_json["ranker"] = "custom_ranker"; + json["ranking_options"] = ranking_json; + + auto result = AssistantFileSearch::FromJson(json); + ASSERT_TRUE(result.has_value()); + + EXPECT_EQ(result.value().max_num_results, 15); + EXPECT_EQ(result.value().ranking_options.score_threshold, 0.85f); + EXPECT_EQ(result.value().ranking_options.ranker, "custom_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchFromInvalidJson) { + Json::Value json; + // Missing required fields + auto result = AssistantFileSearch::FromJson(json); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToJson) { + FileSearchRankingOption ranking_option{0.95f, "advanced_ranker"}; + AssistantFileSearch search{20, 
std::move(ranking_option)}; + + auto json_result = search.ToJson(); + ASSERT_TRUE(json_result.has_value()); + + Json::Value json = json_result.value(); + EXPECT_EQ(json["max_num_results"].asInt(), 20); + EXPECT_EQ(json["ranking_options"]["score_threshold"].asFloat(), 0.95f); + EXPECT_EQ(json["ranking_options"]["ranker"].asString(), "advanced_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolConstruction) { + FileSearchRankingOption ranking_option{0.8f, "tool_ranker"}; + AssistantFileSearch search{25, std::move(ranking_option)}; + AssistantFileSearchTool tool{search}; + + EXPECT_EQ(tool.type, "file_search"); + EXPECT_EQ(tool.file_search.max_num_results, 25); + EXPECT_EQ(tool.file_search.ranking_options.score_threshold, 0.8f); + EXPECT_EQ(tool.file_search.ranking_options.ranker, "tool_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolFromValidJson) { + Json::Value json; + json["type"] = "file_search"; + + Json::Value file_search; + file_search["max_num_results"] = 30; + + Json::Value ranking_options; + ranking_options["score_threshold"] = 0.75f; + ranking_options["ranker"] = "json_ranker"; + file_search["ranking_options"] = ranking_options; + + json["file_search"] = file_search; + + auto result = AssistantFileSearchTool::FromJson(json); + ASSERT_TRUE(result.has_value()); + + EXPECT_EQ(result.value().type, "file_search"); + EXPECT_EQ(result.value().file_search.max_num_results, 30); + EXPECT_EQ(result.value().file_search.ranking_options.score_threshold, 0.75f); + EXPECT_EQ(result.value().file_search.ranking_options.ranker, "json_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolFromInvalidJson) { + Json::Value json; + // Missing required fields + auto result = AssistantFileSearchTool::FromJson(json); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolToJson) { + FileSearchRankingOption ranking_option{0.65f, "final_ranker"}; + AssistantFileSearch 
search{35, std::move(ranking_option)}; + AssistantFileSearchTool tool{search}; + + auto json_result = tool.ToJson(); + ASSERT_TRUE(json_result.has_value()); + + Json::Value json = json_result.value(); + EXPECT_EQ(json["type"].asString(), "file_search"); + EXPECT_EQ(json["file_search"]["max_num_results"].asInt(), 35); + EXPECT_EQ(json["file_search"]["ranking_options"]["score_threshold"].asFloat(), + 0.65f); + EXPECT_EQ(json["file_search"]["ranking_options"]["ranker"].asString(), + "final_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, MoveConstructorsAndAssignments) { + // Test FileSearchRankingOption move operations + FileSearchRankingOption original_option{0.8f, "original_ranker"}; + FileSearchRankingOption moved_option{std::move(original_option)}; + EXPECT_EQ(moved_option.score_threshold, 0.8f); + EXPECT_EQ(moved_option.ranker, "original_ranker"); + + FileSearchRankingOption assign_target{0.5f}; + assign_target = std::move(moved_option); + EXPECT_EQ(assign_target.score_threshold, 0.8f); + EXPECT_EQ(assign_target.ranker, "original_ranker"); + + // Test AssistantFileSearch move operations + FileSearchRankingOption search_option{0.9f, "search_ranker"}; + AssistantFileSearch original_search{40, std::move(search_option)}; + AssistantFileSearch moved_search{std::move(original_search)}; + EXPECT_EQ(moved_search.max_num_results, 40); + EXPECT_EQ(moved_search.ranking_options.score_threshold, 0.9f); + + // Test AssistantFileSearchTool move operations + FileSearchRankingOption tool_option{0.7f, "tool_ranker"}; + AssistantFileSearch tool_search{45, std::move(tool_option)}; + AssistantFileSearchTool original_tool{tool_search}; + AssistantFileSearchTool moved_tool{std::move(original_tool)}; + EXPECT_EQ(moved_tool.type, "file_search"); + EXPECT_EQ(moved_tool.file_search.max_num_results, 45); +} + +TEST_F(AssistantFileSearchToolTest, EdgeCases) { + // Test boundary values for score_threshold + FileSearchRankingOption min_threshold{0.0f}; + 
EXPECT_EQ(min_threshold.score_threshold, 0.0f); + + FileSearchRankingOption max_threshold{1.0f}; + EXPECT_EQ(max_threshold.score_threshold, 1.0f); + + // Test boundary values for max_num_results + FileSearchRankingOption ranking_option{0.5f}; + AssistantFileSearch min_results{1, std::move(ranking_option)}; + EXPECT_EQ(min_results.max_num_results, 1); + + FileSearchRankingOption ranking_option2{0.5f}; + AssistantFileSearch max_results{50, std::move(ranking_option2)}; + EXPECT_EQ(max_results.max_num_results, 50); +} + +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_assistant_tool_function.cc b/engine/test/components/test_assistant_tool_function.cc new file mode 100644 index 000000000..6f59df693 --- /dev/null +++ b/engine/test/components/test_assistant_tool_function.cc @@ -0,0 +1,240 @@ +#include +#include "common/assistant_function_tool.h" +#include + +namespace OpenAi { +namespace { + +class AssistantFunctionTest : public ::testing::Test { +protected: + void SetUp() override { + // Common test setup + basic_description = "Test function description"; + basic_name = "test_function"; + basic_params = Json::Value(Json::objectValue); + basic_params["type"] = "object"; + basic_params["properties"] = Json::Value(Json::objectValue); + } + + std::string basic_description; + std::string basic_name; + Json::Value basic_params; +}; + +TEST_F(AssistantFunctionTest, BasicConstructionWithoutStrict) { + AssistantFunction function(basic_description, basic_name, basic_params, std::nullopt); + + EXPECT_EQ(function.description, basic_description); + EXPECT_EQ(function.name, basic_name); + EXPECT_EQ(function.parameters, basic_params); + EXPECT_FALSE(function.strict.has_value()); +} + +TEST_F(AssistantFunctionTest, BasicConstructionWithStrict) { + AssistantFunction function(basic_description, basic_name, basic_params, true); + + EXPECT_EQ(function.description, basic_description); + EXPECT_EQ(function.name, basic_name); + EXPECT_EQ(function.parameters, 
basic_params); + ASSERT_TRUE(function.strict.has_value()); + EXPECT_TRUE(*function.strict); +} + +TEST_F(AssistantFunctionTest, MoveConstructor) { + AssistantFunction original(basic_description, basic_name, basic_params, true); + + AssistantFunction moved(std::move(original)); + + EXPECT_EQ(moved.description, basic_description); + EXPECT_EQ(moved.name, basic_name); + EXPECT_EQ(moved.parameters, basic_params); + ASSERT_TRUE(moved.strict.has_value()); + EXPECT_TRUE(*moved.strict); +} + +TEST_F(AssistantFunctionTest, MoveAssignment) { + AssistantFunction original(basic_description, basic_name, basic_params, true); + + AssistantFunction target("other", "other_name", Json::Value(Json::objectValue), false); + target = std::move(original); + + EXPECT_EQ(target.description, basic_description); + EXPECT_EQ(target.name, basic_name); + EXPECT_EQ(target.parameters, basic_params); + ASSERT_TRUE(target.strict.has_value()); + EXPECT_TRUE(*target.strict); +} + +TEST_F(AssistantFunctionTest, FromValidJson) { + Json::Value json; + json["description"] = basic_description; + json["name"] = basic_name; + json["strict"] = true; + json["parameters"] = basic_params; + + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_value()); + + const auto& function = result.value(); + EXPECT_EQ(function.description, basic_description); + EXPECT_EQ(function.name, basic_name); + EXPECT_EQ(function.parameters, basic_params); + ASSERT_TRUE(function.strict.has_value()); + EXPECT_TRUE(*function.strict); +} + +TEST_F(AssistantFunctionTest, FromJsonValidationEmptyJson) { + Json::Value json; + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Function json can't be empty"); +} + +TEST_F(AssistantFunctionTest, FromJsonValidationEmptyName) { + Json::Value json; + json["description"] = basic_description; + json["parameters"] = basic_params; + + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); 
+ EXPECT_EQ(result.error(), "Function name can't be empty"); + + // Test with empty name value + json["name"] = ""; + result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Function name can't be empty"); +} + +TEST_F(AssistantFunctionTest, FromJsonValidationMissingDescription) { + Json::Value json; + json["name"] = basic_name; + json["parameters"] = basic_params; + + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Function description is mandatory"); +} + +TEST_F(AssistantFunctionTest, FromJsonValidationMissingParameters) { + Json::Value json; + json["name"] = basic_name; + json["description"] = basic_description; + + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Function parameters are mandatory"); +} + +TEST_F(AssistantFunctionTest, ToJsonWithStrict) { + AssistantFunction function(basic_description, basic_name, basic_params, true); + + auto result = function.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + EXPECT_EQ(json["description"].asString(), basic_description); + EXPECT_EQ(json["name"].asString(), basic_name); + EXPECT_EQ(json["parameters"], basic_params); + EXPECT_TRUE(json["strict"].asBool()); +} + +TEST_F(AssistantFunctionTest, ToJsonWithoutStrict) { + AssistantFunction function(basic_description, basic_name, basic_params, std::nullopt); + + auto result = function.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + EXPECT_EQ(json["description"].asString(), basic_description); + EXPECT_EQ(json["name"].asString(), basic_name); + EXPECT_EQ(json["parameters"], basic_params); + EXPECT_FALSE(json.isMember("strict")); +} + +// AssistantFunctionTool Tests +class AssistantFunctionToolTest : public ::testing::Test { +protected: + void SetUp() override { + description = "Test tool description"; + name = 
"test_tool"; + params = Json::Value(Json::objectValue); + params["type"] = "object"; + } + + std::string description; + std::string name; + Json::Value params; +}; + +TEST_F(AssistantFunctionToolTest, BasicConstruction) { + AssistantFunction function(description, name, params, true); + AssistantFunctionTool tool(function); + + EXPECT_EQ(tool.type, "function"); + EXPECT_EQ(tool.function.description, description); + EXPECT_EQ(tool.function.name, name); + EXPECT_EQ(tool.function.parameters, params); + ASSERT_TRUE(tool.function.strict.has_value()); + EXPECT_TRUE(*tool.function.strict); +} + +TEST_F(AssistantFunctionToolTest, MoveConstructor) { + AssistantFunction function(description, name, params, true); + AssistantFunctionTool original(function); + + AssistantFunctionTool moved(std::move(original)); + + EXPECT_EQ(moved.type, "function"); + EXPECT_EQ(moved.function.description, description); + EXPECT_EQ(moved.function.name, name); + EXPECT_EQ(moved.function.parameters, params); +} + +TEST_F(AssistantFunctionToolTest, FromValidJson) { + Json::Value function_json; + function_json["description"] = description; + function_json["name"] = name; + function_json["strict"] = true; + function_json["parameters"] = params; + + Json::Value json; + json["type"] = "function"; + json["function"] = function_json; + + auto result = AssistantFunctionTool::FromJson(json); + ASSERT_TRUE(result.has_value()); + + const auto& tool = result.value(); + EXPECT_EQ(tool.type, "function"); + EXPECT_EQ(tool.function.description, description); + EXPECT_EQ(tool.function.name, name); + EXPECT_EQ(tool.function.parameters, params); + ASSERT_TRUE(tool.function.strict.has_value()); + EXPECT_TRUE(*tool.function.strict); +} + +TEST_F(AssistantFunctionToolTest, FromInvalidJson) { + Json::Value json; + auto result = AssistantFunctionTool::FromJson(json); + EXPECT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Failed to parse function: Function json can't be empty"); +} + 
+TEST_F(AssistantFunctionToolTest, ToJson) { + AssistantFunction function(description, name, params, true); + AssistantFunctionTool tool(function); + + auto result = tool.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + EXPECT_EQ(json["type"].asString(), "function"); + EXPECT_EQ(json["function"]["description"].asString(), description); + EXPECT_EQ(json["function"]["name"].asString(), name); + EXPECT_EQ(json["function"]["parameters"], params); + EXPECT_TRUE(json["function"]["strict"].asBool()); +} + +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_tool_resources.cc b/engine/test/components/test_tool_resources.cc new file mode 100644 index 000000000..2b78e6494 --- /dev/null +++ b/engine/test/components/test_tool_resources.cc @@ -0,0 +1,212 @@ +#include +#include +#include "common/tool_resources.h" + +namespace OpenAi { +namespace { + +// Mock class for testing abstract ToolResources +class MockToolResources : public ToolResources { + public: + cpp::result ToJson() override { + Json::Value json; + json["mock"] = "value"; + return json; + } +}; + +class ToolResourcesTest : public ::testing::Test {}; + +TEST_F(ToolResourcesTest, MoveConstructor) { + MockToolResources original; + MockToolResources moved(std::move(original)); + + auto json_result = moved.ToJson(); + ASSERT_TRUE(json_result.has_value()); + EXPECT_EQ(json_result.value()["mock"].asString(), "value"); +} + +TEST_F(ToolResourcesTest, MoveAssignment) { + MockToolResources original; + MockToolResources target; + target = std::move(original); + + auto json_result = target.ToJson(); + ASSERT_TRUE(json_result.has_value()); + EXPECT_EQ(json_result.value()["mock"].asString(), "value"); +} + +class CodeInterpreterTest : public ::testing::Test { + protected: + void SetUp() override { sample_file_ids = {"file1", "file2", "file3"}; } + + std::vector sample_file_ids; +}; + +TEST_F(CodeInterpreterTest, DefaultConstruction) { + CodeInterpreter interpreter; 
+ EXPECT_TRUE(interpreter.file_ids.empty()); +} + +TEST_F(CodeInterpreterTest, MoveConstructor) { + CodeInterpreter original; + original.file_ids = sample_file_ids; + + CodeInterpreter moved(std::move(original)); + EXPECT_EQ(moved.file_ids, sample_file_ids); + EXPECT_TRUE(original.file_ids.empty()); // NOLINT: Checking moved-from state +} + +TEST_F(CodeInterpreterTest, MoveAssignment) { + CodeInterpreter original; + original.file_ids = sample_file_ids; + + CodeInterpreter target; + target = std::move(original); + EXPECT_EQ(target.file_ids, sample_file_ids); + EXPECT_TRUE(original.file_ids.empty()); // NOLINT: Checking moved-from state +} + +TEST_F(CodeInterpreterTest, FromJsonWithFileIds) { + Json::Value json; + Json::Value file_ids(Json::arrayValue); + for (const auto& id : sample_file_ids) { + file_ids.append(id); + } + json["file_ids"] = file_ids; + + auto result = CodeInterpreter::FromJson(json); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value().file_ids, sample_file_ids); +} + +TEST_F(CodeInterpreterTest, FromJsonWithoutFileIds) { + Json::Value json; // Empty JSON + auto result = CodeInterpreter::FromJson(json); + ASSERT_TRUE(result.has_value()); + EXPECT_TRUE(result.value().file_ids.empty()); +} + +TEST_F(CodeInterpreterTest, ToJson) { + CodeInterpreter interpreter; + interpreter.file_ids = sample_file_ids; + + auto result = interpreter.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + ASSERT_TRUE(json.isMember("file_ids")); + ASSERT_TRUE(json["file_ids"].isArray()); + ASSERT_EQ(json["file_ids"].size(), sample_file_ids.size()); + + for (Json::ArrayIndex i = 0; i < json["file_ids"].size(); ++i) { + EXPECT_EQ(json["file_ids"][i].asString(), sample_file_ids[i]); + } +} + +TEST_F(CodeInterpreterTest, ToJsonEmptyFileIds) { + CodeInterpreter interpreter; + + auto result = interpreter.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + ASSERT_TRUE(json.isMember("file_ids")); + 
ASSERT_TRUE(json["file_ids"].isArray()); + EXPECT_EQ(json["file_ids"].size(), 0); +} + +class FileSearchTest : public ::testing::Test { + protected: + void SetUp() override { + sample_vector_store_ids = {"store1", "store2", "store3"}; + } + + std::vector sample_vector_store_ids; +}; + +TEST_F(FileSearchTest, DefaultConstruction) { + FileSearch search; + EXPECT_TRUE(search.vector_store_ids.empty()); +} + +TEST_F(FileSearchTest, MoveConstructor) { + FileSearch original; + original.vector_store_ids = sample_vector_store_ids; + + FileSearch moved(std::move(original)); + EXPECT_EQ(moved.vector_store_ids, sample_vector_store_ids); + EXPECT_TRUE( + original.vector_store_ids.empty()); // NOLINT: Checking moved-from state +} + +TEST_F(FileSearchTest, MoveAssignment) { + FileSearch original; + original.vector_store_ids = sample_vector_store_ids; + + FileSearch target; + target = std::move(original); + EXPECT_EQ(target.vector_store_ids, sample_vector_store_ids); + EXPECT_TRUE( + original.vector_store_ids.empty()); // NOLINT: Checking moved-from state +} + +TEST_F(FileSearchTest, FromJsonWithVectorStoreIds) { + Json::Value json; + Json::Value vector_store_ids(Json::arrayValue); + for (const auto& id : sample_vector_store_ids) { + vector_store_ids.append(id); + } + json["vector_store_ids"] = vector_store_ids; + + auto result = FileSearch::FromJson(json); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value().vector_store_ids, sample_vector_store_ids); +} + +TEST_F(FileSearchTest, FromJsonWithoutVectorStoreIds) { + Json::Value json; // Empty JSON + auto result = FileSearch::FromJson(json); + ASSERT_TRUE(result.has_value()); + EXPECT_TRUE(result.value().vector_store_ids.empty()); +} + +TEST_F(FileSearchTest, ToJson) { + FileSearch search; + search.vector_store_ids = sample_vector_store_ids; + + auto result = search.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + ASSERT_TRUE(json.isMember("vector_store_ids")); + 
ASSERT_TRUE(json["vector_store_ids"].isArray()); + ASSERT_EQ(json["vector_store_ids"].size(), sample_vector_store_ids.size()); + + for (Json::ArrayIndex i = 0; i < json["vector_store_ids"].size(); ++i) { + EXPECT_EQ(json["vector_store_ids"][i].asString(), + sample_vector_store_ids[i]); + } +} + +TEST_F(FileSearchTest, ToJsonEmptyVectorStoreIds) { + FileSearch search; + + auto result = search.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + ASSERT_TRUE(json.isMember("vector_store_ids")); + ASSERT_TRUE(json["vector_store_ids"].isArray()); + EXPECT_EQ(json["vector_store_ids"].size(), 0); +} + +TEST_F(FileSearchTest, SelfAssignment) { + FileSearch search; + search.vector_store_ids = sample_vector_store_ids; + + search = std::move(search); // Self-assignment with move + EXPECT_EQ(search.vector_store_ids, sample_vector_store_ids); +} +} // namespace +} // namespace OpenAi From d9bdb81430c6939178801217854b970a3d3b7a0f Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 30 Dec 2024 08:28:40 +0700 Subject: [PATCH 06/16] fix: using engine variant name for download task id (#1833) --- engine/services/engine_service.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 93311f98b..d908aba1b 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -8,7 +8,6 @@ #include "database/engines.h" #include "extensions/remote-engine/remote_engine.h" #include "utils/archive_utils.h" -#include "utils/cpuid/cpu_info.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" #include "utils/file_manager_utils.h" @@ -364,10 +363,10 @@ cpp::result EngineService::DownloadEngine( }; auto downloadTask = - DownloadTask{.id = engine, + DownloadTask{.id = selected_variant->name, .type = DownloadType::Engine, .items = {DownloadItem{ - .id = engine, + .id = selected_variant->name, .downloadUrl = 
selected_variant->browser_download_url, .localPath = variant_path, }}}; From 3b545de4df29d81d83aab09b648ecb9930ca70fd Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 30 Dec 2024 11:08:06 +0700 Subject: [PATCH 07/16] chore: database service (#1834) * chore: remove services namespace * chore: db: engines to service * chore: db: file to service * chore: db: hardware to service * chore: db: models to service --------- Co-authored-by: vansangpfiev --- engine/CMakeLists.txt | 2 +- engine/cli/CMakeLists.txt | 1 + engine/cli/command_line_parser.cc | 14 ++- engine/cli/command_line_parser.h | 1 + engine/cli/commands/chat_completion_cmd.cc | 3 +- engine/cli/commands/chat_completion_cmd.h | 4 + engine/cli/commands/model_start_cmd.cc | 2 +- engine/cli/commands/model_start_cmd.h | 9 +- engine/cli/commands/run_cmd.cc | 18 ++- engine/cli/commands/run_cmd.h | 7 +- engine/cli/commands/server_start_cmd.cc | 3 +- engine/controllers/hardware.h | 4 +- engine/controllers/models.cc | 15 +-- engine/controllers/models.h | 11 +- engine/controllers/server.cc | 10 +- engine/controllers/server.h | 8 +- engine/main.cc | 22 ++-- engine/repositories/file_fs_repository.cc | 14 +-- engine/repositories/file_fs_repository.h | 7 +- engine/services/database_service.cc | 130 +++++++++++++++++++++ engine/services/database_service.h | 68 +++++++++++ engine/services/engine_service.cc | 22 ++-- engine/services/engine_service.h | 8 +- engine/services/hardware_service.cc | 22 ++-- engine/services/hardware_service.h | 9 +- engine/services/inference_service.cc | 2 - engine/services/inference_service.h | 3 - engine/services/model_service.cc | 100 +++++++--------- engine/services/model_service.h | 24 ++-- engine/services/model_source_service.cc | 65 +++++------ engine/services/model_source_service.h | 19 ++- 31 files changed, 406 insertions(+), 221 deletions(-) create mode 100644 engine/services/database_service.cc create mode 100644 engine/services/database_service.h diff --git a/engine/CMakeLists.txt 
b/engine/CMakeLists.txt index 25c0783b1..e82e07aab 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -157,7 +157,7 @@ target_link_libraries(${TARGET_NAME} PRIVATE JsonCpp::JsonCpp Drogon::Drogon Ope target_link_libraries(${TARGET_NAME} PRIVATE SQLiteCpp) target_link_libraries(${TARGET_NAME} PRIVATE eventpp::eventpp) target_link_libraries(${TARGET_NAME} PRIVATE lfreist-hwinfo::hwinfo) - + # ############################################################################## if(CMAKE_CXX_STANDARD LESS 17) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index df4f1a76b..eb29460a7 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -83,6 +83,7 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../services/database_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 825780895..6f8f227e6 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -49,8 +49,9 @@ CommandLineParser::CommandLineParser() : app_("\nCortex.cpp CLI\n"), download_service_{std::make_shared()}, dylib_path_manager_{std::make_shared()}, - engine_service_{std::make_shared(download_service_, - dylib_path_manager_)} { + db_service_{std::make_shared()}, + engine_service_{std::make_shared( + download_service_, dylib_path_manager_, db_service_)} { supported_engines_ = engine_service_->GetSupportedEngineNames().value(); } @@ -177,7 +178,7 @@ void CommandLineParser::SetupCommonCommands() { return; commands::RunCmd rc(cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), 
- cml_data_.model_id, engine_service_); + cml_data_.model_id, db_service_, engine_service_); rc.Exec(cml_data_.run_detach, run_settings_); }); } @@ -216,9 +217,10 @@ void CommandLineParser::SetupModelCommands() { CLI_LOG(model_start_cmd->help()); return; }; - commands::ModelStartCmd().Exec(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), - cml_data_.model_id, run_settings_); + commands::ModelStartCmd(db_service_) + .Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id, + run_settings_); }); auto stop_model_cmd = diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index 14e10e420..5b64f7f4d 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -45,6 +45,7 @@ class CommandLineParser { CLI::App app_; std::shared_ptr download_service_; std::shared_ptr dylib_path_manager_; + std::shared_ptr db_service_; std::shared_ptr engine_service_; std::vector supported_engines_; diff --git a/engine/cli/commands/chat_completion_cmd.cc b/engine/cli/commands/chat_completion_cmd.cc index 77d222176..77ee4fca3 100644 --- a/engine/cli/commands/chat_completion_cmd.cc +++ b/engine/cli/commands/chat_completion_cmd.cc @@ -56,10 +56,9 @@ void ChatCompletionCmd::Exec(const std::string& host, int port, const std::string& model_handle, std::string msg) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CLI_LOG("Error: " + model_entry.error()); return; diff --git a/engine/cli/commands/chat_completion_cmd.h b/engine/cli/commands/chat_completion_cmd.h index a784b4604..44de5d256 100644 --- a/engine/cli/commands/chat_completion_cmd.h +++ b/engine/cli/commands/chat_completion_cmd.h @@ -3,16 +3,20 @@ #include 
#include #include "config/model_config.h" +#include "services/database_service.h" namespace commands { class ChatCompletionCmd { public: + explicit ChatCompletionCmd(std::shared_ptr db_service) + : db_service_(db_service) {} void Exec(const std::string& host, int port, const std::string& model_handle, std::string msg); void Exec(const std::string& host, int port, const std::string& model_handle, const config::ModelConfig& mc, std::string msg); private: + std::shared_ptr db_service_; std::vector histories_; }; } // namespace commands diff --git a/engine/cli/commands/model_start_cmd.cc b/engine/cli/commands/model_start_cmd.cc index 12aec944d..ef5d5c1f2 100644 --- a/engine/cli/commands/model_start_cmd.cc +++ b/engine/cli/commands/model_start_cmd.cc @@ -13,7 +13,7 @@ bool ModelStartCmd::Exec( const std::unordered_map& options, bool print_success_log) { std::optional model_id = - SelectLocalModel(host, port, model_handle); + SelectLocalModel(host, port, model_handle, *db_service_); if (!model_id.has_value()) { return false; diff --git a/engine/cli/commands/model_start_cmd.h b/engine/cli/commands/model_start_cmd.h index 124ef463d..c69bfc32a 100644 --- a/engine/cli/commands/model_start_cmd.h +++ b/engine/cli/commands/model_start_cmd.h @@ -3,16 +3,23 @@ #include #include #include "json/json.h" +#include "services/database_service.h" namespace commands { class ModelStartCmd { public: + explicit ModelStartCmd(std::shared_ptr db_service) + : db_service_(db_service) {} bool Exec(const std::string& host, int port, const std::string& model_handle, const std::unordered_map& options, bool print_success_log = true); - private: + + private: bool UpdateConfig(Json::Value& data, const std::string& key, const std::string& value); + + private: + std::shared_ptr db_service_; }; } // namespace commands diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc index 91a813d64..c01d3d806 100644 --- a/engine/cli/commands/run_cmd.cc +++ b/engine/cli/commands/run_cmd.cc @@ 
-14,12 +14,11 @@ namespace commands { std::optional SelectLocalModel(std::string host, int port, - const std::string& model_handle) { + const std::string& model_handle, + DatabaseService& db_service) { std::optional model_id = model_handle; - cortex::db::Models modellist_handler; - if (model_handle.empty()) { - auto all_local_models = modellist_handler.LoadModelList(); + auto all_local_models = db_service.LoadModelList(); if (all_local_models.has_error() || all_local_models.value().empty()) { CLI_LOG("No local models available!"); return std::nullopt; @@ -42,7 +41,7 @@ std::optional SelectLocalModel(std::string host, int port, CLI_LOG("Selected: " << selection.value()); } } else { - auto related_models_ids = modellist_handler.FindRelatedModel(model_handle); + auto related_models_ids = db_service.FindRelatedModel(model_handle); if (related_models_ids.has_error() || related_models_ids.value().empty()) { auto result = ModelPullCmd().Exec(host, port, model_handle); if (!result) { @@ -69,19 +68,18 @@ std::optional SelectLocalModel(std::string host, int port, void RunCmd::Exec(bool run_detach, const std::unordered_map& options) { std::optional model_id = - SelectLocalModel(host_, port_, model_handle_); + SelectLocalModel(host_, port_, model_handle_, *db_service_); if (!model_id.has_value()) { return; } - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; auto address = host_ + ":" + std::to_string(port_); try { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - auto model_entry = modellist_handler.GetModelInfo(*model_id); + auto model_entry = db_service_->GetModelInfo(*model_id); if (model_entry.has_error()) { CLI_LOG("Error: " + model_entry.error()); return; @@ -128,7 +126,7 @@ void RunCmd::Exec(bool run_detach, mc.engine.find(kLlamaEngine) == std::string::npos) || !commands::ModelStatusCmd().IsLoaded(host_, port_, *model_id)) { - auto res = commands::ModelStartCmd() + auto res = commands::ModelStartCmd(db_service_) .Exec(host_, 
port_, *model_id, options, false /*print_success_log*/); if (!res) { @@ -144,7 +142,7 @@ void RunCmd::Exec(bool run_detach, << commands::GetCortexBinary() << " run " << *model_id << "` for interactive chat shell"); } else { - ChatCompletionCmd().Exec(host_, port_, *model_id, mc, ""); + ChatCompletionCmd(db_service_).Exec(host_, port_, *model_id, mc, ""); } } } catch (const std::exception& e) { diff --git a/engine/cli/commands/run_cmd.h b/engine/cli/commands/run_cmd.h index b22b064f9..ec5c61fd3 100644 --- a/engine/cli/commands/run_cmd.h +++ b/engine/cli/commands/run_cmd.h @@ -2,20 +2,24 @@ #include #include +#include "services/database_service.h" #include "services/engine_service.h" namespace commands { std::optional SelectLocalModel(std::string host, int port, - const std::string& model_handle); + const std::string& model_handle, + DatabaseService& db_service); class RunCmd { public: explicit RunCmd(std::string host, int port, std::string model_handle, + std::shared_ptr db_service, std::shared_ptr engine_service) : host_{std::move(host)}, port_{port}, model_handle_{std::move(model_handle)}, + db_service_(db_service), engine_service_{engine_service} {}; void Exec(bool chat_flag, @@ -25,6 +29,7 @@ class RunCmd { std::string host_; int port_; std::string model_handle_; + std::shared_ptr db_service_; std::shared_ptr engine_service_; }; } // namespace commands diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index 3d6045cd5..4268f6362 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -114,7 +114,8 @@ bool ServerStartCmd::Exec(const std::string& host, int port, // Some engines requires to add lib search path before process being created auto download_srv = std::make_shared(); auto dylib_path_mng = std::make_shared(); - EngineService(download_srv, dylib_path_mng).RegisterEngineLibPath(); + auto db_srv = std::make_shared(); + EngineService(download_srv, dylib_path_mng, 
db_srv).RegisterEngineLibPath(); std::string p = cortex_utils::GetCurrentPath() + "/" + exe; execl(p.c_str(), exe.c_str(), "--start-server", "--config_file_path", diff --git a/engine/controllers/hardware.h b/engine/controllers/hardware.h index 6cca4fd2a..8b2b551ce 100644 --- a/engine/controllers/hardware.h +++ b/engine/controllers/hardware.h @@ -9,7 +9,7 @@ using namespace drogon; class Hardware : public drogon::HttpController { public: explicit Hardware(std::shared_ptr engine_svc, - std::shared_ptr hw_svc) + std::shared_ptr hw_svc) : engine_svc_(engine_svc), hw_svc_(hw_svc) {} METHOD_LIST_BEGIN METHOD_ADD(Hardware::GetHardwareInfo, "/hardware", Get); @@ -27,5 +27,5 @@ class Hardware : public drogon::HttpController { private: std::shared_ptr engine_svc_ = nullptr; - std::shared_ptr hw_svc_= nullptr; + std::shared_ptr hw_svc_= nullptr; }; \ No newline at end of file diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 1c33ab1dc..1a501287d 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -165,10 +165,9 @@ void Models::ListModel( model_service_->ForceIndexingModelList(); // Iterate through directory - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; - auto list_entry = modellist_handler.LoadModelList(); + auto list_entry = db_service_->LoadModelList(); if (list_entry) { for (const auto& model_entry : list_entry.value()) { try { @@ -256,9 +255,8 @@ void Models::GetModel(const HttpRequestPtr& req, Json::Value ret; try { - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; - auto model_entry = modellist_handler.GetModelInfo(model_id); + auto model_entry = db_service_->GetModelInfo(model_id); if (model_entry.has_error()) { ret["id"] = model_id; ret["object"] = "model"; @@ -337,8 +335,7 @@ void Models::UpdateModel(const HttpRequestPtr& req, namespace fmu = file_manager_utils; auto json_body = *(req->getJsonObject()); try { - cortex::db::Models model_list_utils; - auto 
model_entry = model_list_utils.GetModelInfo(model_id); + auto model_entry = db_service_->GetModelInfo(model_id); config::YamlHandler yaml_handler; auto yaml_fp = fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.value().path_to_model_yaml)); @@ -401,7 +398,6 @@ void Models::ImportModel( auto option = (*(req->getJsonObject())).get("option", "symlink").asString(); config::GGUFHandler gguf_handler; config::YamlHandler yaml_handler; - cortex::db::Models modellist_utils_obj; std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() / std::filesystem::path("imported") / std::filesystem::path(modelHandle + ".yml")) @@ -440,7 +436,7 @@ void Models::ImportModel( model_config.name = modelName.empty() ? model_config.name : modelName; yaml_handler.UpdateModelConfig(model_config); - if (modellist_utils_obj.AddModelEntry(model_entry).value()) { + if (db_service_->AddModelEntry(model_entry).value()) { yaml_handler.WriteYamlFile(model_yaml_path); std::string success_message = "Model is imported successfully!"; LOG_INFO << success_message; @@ -667,7 +663,6 @@ void Models::AddRemoteModel( config::RemoteModelConfig model_config; model_config.LoadFromJson(*(req->getJsonObject())); - cortex::db::Models modellist_utils_obj; std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() / std::filesystem::path("remote") / std::filesystem::path(model_handle + ".yml")) @@ -683,7 +678,7 @@ void Models::AddRemoteModel( "openai"}; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); - if (modellist_utils_obj.AddModelEntry(model_entry).value()) { + if (db_service_->AddModelEntry(model_entry).value()) { model_config.SaveToYamlFile(model_yaml_path); std::string success_message = "Model is imported successfully!"; LOG_INFO << success_message; diff --git a/engine/controllers/models.h b/engine/controllers/models.h index d3200f33a..60053acdb 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -45,10 
+45,12 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::GetModelSources, "/v1/models/sources", Get); METHOD_LIST_END - explicit Models(std::shared_ptr model_service, + explicit Models(std::shared_ptr db_service, + std::shared_ptr model_service, std::shared_ptr engine_service, - std::shared_ptr mss) - : model_service_{model_service}, + std::shared_ptr mss) + : db_service_(db_service), + model_service_{model_service}, engine_service_{engine_service}, model_src_svc_(mss) {} @@ -105,7 +107,8 @@ class Models : public drogon::HttpController { std::function&& callback); private: + std::shared_ptr db_service_; std::shared_ptr model_service_; std::shared_ptr engine_service_; - std::shared_ptr model_src_svc_; + std::shared_ptr model_src_svc_; }; diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 19842bcdb..d8e29eb1b 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -8,7 +8,7 @@ using namespace inferences; namespace inferences { -server::server(std::shared_ptr inference_service, +server::server(std::shared_ptr inference_service, std::shared_ptr engine_service) : inference_svc_(inference_service), engine_service_(engine_service) { #if defined(_WIN32) @@ -45,7 +45,7 @@ void server::ChatCompletion( }(); LOG_DEBUG << "request body: " << json_body->toStyledString(); - auto q = std::make_shared(); + auto q = std::make_shared(); auto ir = inference_svc_->HandleChatCompletion(q, json_body); if (ir.has_error()) { auto err = ir.error(); @@ -67,7 +67,7 @@ void server::ChatCompletion( void server::Embedding(const HttpRequestPtr& req, std::function&& callback) { LOG_TRACE << "Start embedding"; - auto q = std::make_shared(); + auto q = std::make_shared(); auto ir = inference_svc_->HandleEmbedding(q, req->getJsonObject()); if (ir.has_error()) { auto err = ir.error(); @@ -138,7 +138,7 @@ void server::LoadModel(const HttpRequestPtr& req, } void server::ProcessStreamRes(std::function cb, - std::shared_ptr q, + 
std::shared_ptr q, const std::string& engine_type, const std::string& model_id) { auto err_or_done = std::make_shared(false); @@ -178,7 +178,7 @@ void server::ProcessStreamRes(std::function cb, } void server::ProcessNonStreamRes(std::function cb, - services::SyncQueue& q) { + SyncQueue& q) { auto [status, res] = q.wait_and_pop(); function_calling_utils::PostProcessResponse(res); LOG_DEBUG << "response: " << res.toStyledString(); diff --git a/engine/controllers/server.h b/engine/controllers/server.h index 22ea86c30..ef8a32f5d 100644 --- a/engine/controllers/server.h +++ b/engine/controllers/server.h @@ -27,7 +27,7 @@ class server : public drogon::HttpController, public BaseChatCompletion, public BaseEmbedding { public: - server(std::shared_ptr inference_service, + server(std::shared_ptr inference_service, std::shared_ptr engine_service); ~server(); METHOD_LIST_BEGIN @@ -72,14 +72,14 @@ class server : public drogon::HttpController, private: void ProcessStreamRes(std::function cb, - std::shared_ptr q, + std::shared_ptr q, const std::string& engine_type, const std::string& model_id); void ProcessNonStreamRes(std::function cb, - services::SyncQueue& q); + SyncQueue& q); private: - std::shared_ptr inference_svc_; + std::shared_ptr inference_svc_; std::shared_ptr engine_service_; }; }; // namespace inferences diff --git a/engine/main.cc b/engine/main.cc index 938392bf0..77f51c7fa 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -21,6 +21,7 @@ #include "repositories/thread_fs_repository.h" #include "services/assistant_service.h" #include "services/config_service.h" +#include "services/database_service.h" #include "services/file_watcher_service.h" #include "services/message_service.h" #include "services/model_service.h" @@ -120,7 +121,8 @@ void RunServer(std::optional host, std::optional port, LOG_INFO << "cortex.cpp version: undefined"; #endif - auto hw_service = std::make_shared(); + auto db_service = std::make_shared(); + auto hw_service = 
std::make_shared(db_service); hw_service->UpdateHardwareInfos(); if (hw_service->ShouldRestart()) { CTL_INF("Restart to update hardware configuration"); @@ -140,7 +142,8 @@ void RunServer(std::optional host, std::optional port, // utils auto dylib_path_manager = std::make_shared(); - auto file_repo = std::make_shared(data_folder_path); + auto file_repo = + std::make_shared(data_folder_path, db_service); auto msg_repo = std::make_shared(data_folder_path); auto thread_repo = std::make_shared(data_folder_path); auto assistant_repo = @@ -156,13 +159,12 @@ void RunServer(std::optional host, std::optional port, auto config_service = std::make_shared(); auto download_service = std::make_shared(event_queue_ptr, config_service); - auto engine_service = - std::make_shared(download_service, dylib_path_manager); - auto inference_svc = - std::make_shared(engine_service); - auto model_src_svc = std::make_shared(); + auto engine_service = std::make_shared( + download_service, dylib_path_manager, db_service); + auto inference_svc = std::make_shared(engine_service); + auto model_src_svc = std::make_shared(db_service); auto model_service = std::make_shared( - download_service, inference_svc, engine_service); + db_service, hw_service, download_service, inference_svc, engine_service); inference_svc->SetModelService(model_service); auto file_watcher_srv = std::make_shared( @@ -177,8 +179,8 @@ void RunServer(std::optional host, std::optional port, auto thread_ctl = std::make_shared(thread_srv, message_srv); auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); - auto model_ctl = - std::make_shared(model_service, engine_service, model_src_svc); + auto model_ctl = std::make_shared(db_service, model_service, + engine_service, model_src_svc); auto event_ctl = std::make_shared(event_queue_ptr); auto pm_ctl = std::make_shared(); auto hw_ctl = std::make_shared(engine_service, hw_service); diff --git a/engine/repositories/file_fs_repository.cc 
b/engine/repositories/file_fs_repository.cc index a209d33c3..e6c28b38e 100644 --- a/engine/repositories/file_fs_repository.cc +++ b/engine/repositories/file_fs_repository.cc @@ -17,7 +17,6 @@ cpp::result FileFsRepository::StoreFile( std::filesystem::create_directories(file_container_path); } - cortex::db::File db; auto original_filename = file_metadata.filename; auto file_full_path = file_container_path / original_filename; @@ -53,7 +52,7 @@ cpp::result FileFsRepository::StoreFile( file.flush(); file.close(); - auto result = db.AddFileEntry(file_metadata); + auto result = db_service_->AddFileEntry(file_metadata); if (result.has_error()) { std::filesystem::remove(file_full_path); return cpp::fail(result.error()); @@ -70,8 +69,7 @@ cpp::result FileFsRepository::StoreFile( cpp::result, std::string> FileFsRepository::ListFiles( const std::string& purpose, uint8_t limit, const std::string& order, const std::string& after) const { - cortex::db::File db; - auto res = db.GetFileList(); + auto res = db_service_->GetFileList(); if (res.has_error()) { return cpp::fail(res.error()); } @@ -101,8 +99,7 @@ cpp::result FileFsRepository::RetrieveFile( CTL_INF("Retrieving file: " + file_id); auto file_container_path = GetFilePath(); - cortex::db::File db; - auto res = db.GetFileById(file_id); + auto res = db_service_->GetFileById(file_id); if (res.has_error()) { return cpp::fail(res.error()); } @@ -158,15 +155,14 @@ cpp::result FileFsRepository::DeleteFileLocal( const std::string& file_id) { CTL_INF("Deleting file: " + file_id); auto file_container_path = GetFilePath(); - cortex::db::File db; - auto file_metadata = db.GetFileById(file_id); + auto file_metadata = db_service_->GetFileById(file_id); if (file_metadata.has_error()) { return cpp::fail(file_metadata.error()); } auto file_path = file_container_path / file_metadata->filename; - auto res = db.DeleteFileEntry(file_id); + auto res = db_service_->DeleteFileEntry(file_id); if (res.has_error()) { CTL_ERR("Failed to delete file 
entry: " << res.error()); return cpp::fail(res.error()); diff --git a/engine/repositories/file_fs_repository.h b/engine/repositories/file_fs_repository.h index 77af60dfc..e2ad424a7 100644 --- a/engine/repositories/file_fs_repository.h +++ b/engine/repositories/file_fs_repository.h @@ -2,6 +2,7 @@ #include #include "common/repository/file_repository.h" +#include "services/database_service.h" #include "utils/logging_utils.h" class FileFsRepository : public FileRepository { @@ -28,8 +29,9 @@ class FileFsRepository : public FileRepository { cpp::result DeleteFileLocal( const std::string& file_id) override; - explicit FileFsRepository(const std::filesystem::path& data_folder_path) - : data_folder_path_{data_folder_path} { + explicit FileFsRepository(const std::filesystem::path& data_folder_path, + std::shared_ptr db_service) + : data_folder_path_{data_folder_path}, db_service_(db_service) { CTL_INF("Constructing FileFsRepository.."); auto file_container_path = data_folder_path_ / kFileContainerFolderName; @@ -47,4 +49,5 @@ class FileFsRepository : public FileRepository { * The path to the data folder. 
*/ std::filesystem::path data_folder_path_; + std::shared_ptr db_service_ = nullptr; }; diff --git a/engine/services/database_service.cc b/engine/services/database_service.cc new file mode 100644 index 000000000..d4cd977a9 --- /dev/null +++ b/engine/services/database_service.cc @@ -0,0 +1,130 @@ +#include "database_service.h" + +// begin engines +std::optional DatabaseService::UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata) { + return cortex::db::Engines().UpsertEngine(engine_name, type, api_key, url, + version, variant, status, metadata); +} + +std::optional> DatabaseService::GetEngines() const { + return cortex::db::Engines().GetEngines(); +} + +std::optional DatabaseService::GetEngineById(int id) const { + return cortex::db::Engines().GetEngineById(id); +} + +std::optional DatabaseService::GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant) const { + return cortex::db::Engines().GetEngineByNameAndVariant(engine_name, variant); +} + +std::optional DatabaseService::DeleteEngineById(int id) { + return cortex::db::Engines().DeleteEngineById(id); +} +// end engines + +// begin file +cpp::result, std::string> +DatabaseService::GetFileList() const { + return cortex::db::File().GetFileList(); +} + +cpp::result DatabaseService::GetFileById( + const std::string& file_id) const { + return cortex::db::File().GetFileById(file_id); +} + +cpp::result DatabaseService::AddFileEntry( + OpenAi::File& file) { + return cortex::db::File().AddFileEntry(file); +} + +cpp::result DatabaseService::DeleteFileEntry( + const std::string& file_id) { + return cortex::db::File().DeleteFileEntry(file_id); +} +// end file + +// begin hardware +cpp::result, std::string> +DatabaseService::LoadHardwareList() const { + return cortex::db::Hardware().LoadHardwareList(); 
+} + +cpp::result DatabaseService::AddHardwareEntry( + const HardwareEntry& new_entry) { + return cortex::db::Hardware().AddHardwareEntry(new_entry); +} + +cpp::result DatabaseService::UpdateHardwareEntry( + const std::string& id, const HardwareEntry& updated_entry) { + return cortex::db::Hardware().UpdateHardwareEntry(id, updated_entry); +} + +cpp::result DatabaseService::DeleteHardwareEntry( + const std::string& id) { + return cortex::db::Hardware().DeleteHardwareEntry(id); +} +// end hardware + +// begin models +cpp::result, std::string> +DatabaseService::LoadModelList() const { + return cortex::db::Models().LoadModelList(); +} + +cpp::result DatabaseService::GetModelInfo( + const std::string& identifier) const { + return cortex::db::Models().GetModelInfo(identifier); +} + +cpp::result DatabaseService::AddModelEntry( + ModelEntry new_entry) { + return cortex::db::Models().AddModelEntry(new_entry); +} + +cpp::result DatabaseService::UpdateModelEntry( + const std::string& identifier, const ModelEntry& updated_entry) { + return cortex::db::Models().UpdateModelEntry(identifier, updated_entry); +} + +cpp::result DatabaseService::DeleteModelEntry( + const std::string& identifier) { + return cortex::db::Models().DeleteModelEntry(identifier); +} + +cpp::result DatabaseService::DeleteModelEntryWithOrg( + const std::string& src) { + return cortex::db::Models().DeleteModelEntryWithOrg(src); +} + +cpp::result DatabaseService::DeleteModelEntryWithRepo( + const std::string& src) { + return cortex::db::Models().DeleteModelEntryWithRepo(src); +} + +cpp::result, std::string> +DatabaseService::FindRelatedModel(const std::string& identifier) const { + return cortex::db::Models().FindRelatedModel(identifier); +} + +bool DatabaseService::HasModel(const std::string& identifier) const { + return cortex::db::Models().HasModel(identifier); +} + +cpp::result, std::string> +DatabaseService::GetModelSources() const { + return cortex::db::Models().GetModelSources(); +} + +cpp::result, 
std::string> DatabaseService::GetModels( + const std::string& model_src) const { + return cortex::db::Models().GetModels(model_src); +} +// end models \ No newline at end of file diff --git a/engine/services/database_service.h b/engine/services/database_service.h new file mode 100644 index 000000000..4fb4f7be0 --- /dev/null +++ b/engine/services/database_service.h @@ -0,0 +1,68 @@ +#pragma once +#include "database/engines.h" +#include "database/file.h" +#include "database/hardware.h" +#include "database/models.h" + +using EngineEntry = cortex::db::EngineEntry; +using HardwareEntry = cortex::db::HardwareEntry; +using ModelEntry = cortex::db::ModelEntry; + +class DatabaseService { + public: + // engines + std::optional UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata); + + std::optional> GetEngines() const; + std::optional GetEngineById(int id) const; + std::optional GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant = std::nullopt) const; + + std::optional DeleteEngineById(int id); + + // file + cpp::result, std::string> GetFileList() const; + + cpp::result GetFileById( + const std::string& file_id) const; + + cpp::result AddFileEntry(OpenAi::File& file); + + cpp::result DeleteFileEntry(const std::string& file_id); + + // hardware + cpp::result, std::string> LoadHardwareList() const; + cpp::result AddHardwareEntry( + const HardwareEntry& new_entry); + cpp::result UpdateHardwareEntry( + const std::string& id, const HardwareEntry& updated_entry); + cpp::result DeleteHardwareEntry(const std::string& id); + + // models + cpp::result, std::string> LoadModelList() const; + cpp::result GetModelInfo( + const std::string& identifier) const; + void PrintModelInfo(const ModelEntry& entry) const; + cpp::result AddModelEntry(ModelEntry new_entry); + 
cpp::result UpdateModelEntry( + const std::string& identifier, const ModelEntry& updated_entry); + cpp::result DeleteModelEntry( + const std::string& identifier); + cpp::result DeleteModelEntryWithOrg( + const std::string& src); + cpp::result DeleteModelEntryWithRepo( + const std::string& src); + cpp::result, std::string> FindRelatedModel( + const std::string& identifier) const; + bool HasModel(const std::string& identifier) const; + cpp::result, std::string> GetModelSources() const; + cpp::result, std::string> GetModels( + const std::string& model_src) const; + + private: +}; \ No newline at end of file diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index d908aba1b..53a4bfa65 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -1031,8 +1031,8 @@ cpp::result EngineService::UpdateEngine( cpp::result, std::string> EngineService::GetEngines() { - cortex::db::Engines engines; - auto get_res = engines.GetEngines(); + assert(db_service_); + auto get_res = db_service_->GetEngines(); if (!get_res.has_value()) { return cpp::fail("Failed to get engine entries"); @@ -1043,8 +1043,8 @@ EngineService::GetEngines() { cpp::result EngineService::GetEngineById( int id) { - cortex::db::Engines engines; - auto get_res = engines.GetEngineById(id); + assert(db_service_); + auto get_res = db_service_->GetEngineById(id); if (!get_res.has_value()) { return cpp::fail("Engine with ID " + std::to_string(id) + " not found"); @@ -1057,8 +1057,8 @@ cpp::result EngineService::GetEngineByNameAndVariant( const std::string& engine_name, const std::optional variant) { - cortex::db::Engines engines; - auto get_res = engines.GetEngineByNameAndVariant(engine_name, variant); + assert(db_service_); + auto get_res = db_service_->GetEngineByNameAndVariant(engine_name, variant); if (!get_res.has_value()) { if (variant.has_value()) { @@ -1077,9 +1077,9 @@ cpp::result EngineService::UpsertEngine( const std::string& api_key, const 
std::string& url, const std::string& version, const std::string& variant, const std::string& status, const std::string& metadata) { - cortex::db::Engines engines; - auto upsert_res = engines.UpsertEngine(engine_name, type, api_key, url, - version, variant, status, metadata); + assert(db_service_); + auto upsert_res = db_service_->UpsertEngine( + engine_name, type, api_key, url, version, variant, status, metadata); if (upsert_res.has_value()) { return upsert_res.value(); } else { @@ -1088,8 +1088,8 @@ cpp::result EngineService::UpsertEngine( } std::string EngineService::DeleteEngine(int id) { - cortex::db::Engines engines; - auto delete_res = engines.DeleteEngineById(id); + assert(db_service_); + auto delete_res = db_service_->DeleteEngineById(id); if (delete_res.has_value()) { return delete_res.value(); } else { diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 8ead4f6d6..fcd3fdda9 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -12,6 +12,7 @@ #include "cortex-common/cortexpythoni.h" #include "cortex-common/remote_enginei.h" #include "database/engines.h" +#include "services/database_service.h" #include "services/download_service.h" #include "utils/cpuid/cpu_info.h" #include "utils/dylib.h" @@ -59,16 +60,19 @@ class EngineService : public EngineServiceI { std::string cuda_driver_version; }; HardwareInfo hw_inf_; + std::shared_ptr db_service_ = nullptr; public: explicit EngineService( std::shared_ptr download_service, - std::shared_ptr dylib_path_manager) + std::shared_ptr dylib_path_manager, + std::shared_ptr db_service) : download_service_{download_service}, dylib_path_manager_{dylib_path_manager}, hw_inf_{.sys_inf = system_info_utils::GetSystemInfo(), .cuda_driver_version = - system_info_utils::GetDriverAndCudaVersion().second} {} + system_info_utils::GetDriverAndCudaVersion().second}, + db_service_(db_service) {} std::vector GetEngineInfoList() const; diff --git 
a/engine/services/hardware_service.cc b/engine/services/hardware_service.cc index ca2bd8ed9..5552aca56 100644 --- a/engine/services/hardware_service.cc +++ b/engine/services/hardware_service.cc @@ -11,8 +11,6 @@ #include "database/hardware.h" #include "utils/cortex_utils.h" -namespace services { - namespace { bool TryConnectToServer(const std::string& host, int port) { constexpr const auto kMaxRetry = 4u; @@ -34,9 +32,8 @@ bool TryConnectToServer(const std::string& host, int port) { HardwareInfo HardwareService::GetHardwareInfo() { // append active state - cortex::db::Hardware hw_db; auto gpus = cortex::hw::GetGPUInfo(); - auto res = hw_db.LoadHardwareList(); + auto res = db_service_->LoadHardwareList(); if (res.has_value()) { // Only a few elements, brute-force is enough for (auto& entry : res.value()) { @@ -210,7 +207,6 @@ bool HardwareService::SetActivateHardwareConfig( const cortex::hw::ActivateHardwareConfig& ahc) { // Note: need to map software_id and hardware_id // Update to db - cortex::db::Hardware hw_db; // copy all gpu information to new vector auto ahc_gpus = ahc.gpus; auto activate = [&ahc](int software_id) { @@ -225,7 +221,7 @@ bool HardwareService::SetActivateHardwareConfig( return INT_MAX; }; - auto res = hw_db.LoadHardwareList(); + auto res = db_service_->LoadHardwareList(); if (res.has_value()) { bool need_update = false; std::vector> activated_ids; @@ -258,7 +254,7 @@ bool HardwareService::SetActivateHardwareConfig( for (auto& e : res.value()) { e.activated = activate(e.software_id); e.priority = priority(e.software_id); - auto res = hw_db.UpdateHardwareEntry(e.uuid, e); + auto res = db_service_->UpdateHardwareEntry(e.uuid, e); if (res.has_error()) { CTL_WRN(res.error()); } @@ -271,8 +267,7 @@ bool HardwareService::SetActivateHardwareConfig( void HardwareService::UpdateHardwareInfos() { using HwEntry = cortex::db::HardwareEntry; auto gpus = cortex::hw::GetGPUInfo(); - cortex::db::Hardware hw_db; - auto b = hw_db.LoadHardwareList(); + auto b = 
db_service_->LoadHardwareList(); std::vector> activated_gpu_bf; std::string debug_b; for (auto const& he : b.value()) { @@ -285,7 +280,8 @@ void HardwareService::UpdateHardwareInfos() { for (auto const& gpu : gpus) { // ignore error // Note: only support NVIDIA for now, so hardware_id = software_id - auto res = hw_db.AddHardwareEntry(HwEntry{.uuid = gpu.uuid, + auto res = + db_service_->AddHardwareEntry(HwEntry{.uuid = gpu.uuid, .type = "gpu", .hardware_id = std::stoi(gpu.id), .software_id = std::stoi(gpu.id), @@ -296,7 +292,7 @@ void HardwareService::UpdateHardwareInfos() { } } - auto a = hw_db.LoadHardwareList(); + auto a = db_service_->LoadHardwareList(); std::vector a_gpu; std::vector> activated_gpu_af; std::string debug_a; @@ -350,11 +346,10 @@ bool HardwareService::IsValidConfig( const cortex::hw::ActivateHardwareConfig& ahc) { if (ahc.gpus.empty()) return true; - cortex::db::Hardware hw_db; auto is_valid = [&ahc](int software_id) { return std::count(ahc.gpus.begin(), ahc.gpus.end(), software_id) > 0; }; - auto res = hw_db.LoadHardwareList(); + auto res = db_service_->LoadHardwareList(); if (res.has_value()) { for (auto const& e : res.value()) { if (is_valid(e.software_id)) { @@ -364,4 +359,3 @@ bool HardwareService::IsValidConfig( } return false; } -} // namespace services diff --git a/engine/services/hardware_service.h b/engine/services/hardware_service.h index 48ab7a4b1..ad9d70233 100644 --- a/engine/services/hardware_service.h +++ b/engine/services/hardware_service.h @@ -4,6 +4,7 @@ #include #include "common/hardware_config.h" +#include "database_service.h" #include "utils/hardware/cpu_info.h" #include "utils/hardware/gpu_info.h" #include "utils/hardware/os_info.h" @@ -11,8 +12,6 @@ #include "utils/hardware/ram_info.h" #include "utils/hardware/storage_info.h" -namespace services { - struct HardwareInfo { cortex::hw::CPU cpu; cortex::hw::OS os; @@ -24,6 +23,8 @@ struct HardwareInfo { class HardwareService { public: + explicit 
HardwareService(std::shared_ptr db_service) + : db_service_(db_service) {} HardwareInfo GetHardwareInfo(); bool Restart(const std::string& host, int port); bool SetActivateHardwareConfig(const cortex::hw::ActivateHardwareConfig& ahc); @@ -32,6 +33,6 @@ class HardwareService { bool IsValidConfig(const cortex::hw::ActivateHardwareConfig& ahc); private: + std::shared_ptr db_service_ = nullptr; std::optional ahc_; -}; -} // namespace services +}; \ No newline at end of file diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 08107562b..9d8e9f4f8 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -4,7 +4,6 @@ #include "utils/function_calling/common.h" #include "utils/jinja_utils.h" -namespace services { cpp::result InferenceService::HandleChatCompletion( std::shared_ptr q, std::shared_ptr json_body) { std::string engine_type; @@ -337,4 +336,3 @@ bool InferenceService::HasFieldInReq(std::shared_ptr json_body, } return true; } -} // namespace services diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index 54bc9dc29..75b07b1a3 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -7,8 +7,6 @@ #include "services/model_service.h" #include "utils/result.hpp" -namespace services { - // Status and result using InferResult = std::pair; @@ -68,4 +66,3 @@ class InferenceService { std::shared_ptr engine_service_; std::weak_ptr model_service_; }; -} // namespace services diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index be0eb12a7..2d69e0f17 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -21,7 +21,8 @@ #include "utils/widechar_conv.h" namespace { -void ParseGguf(const DownloadItem& ggufDownloadItem, +void ParseGguf(DatabaseService& db_service, + const DownloadItem& ggufDownloadItem, std::optional author, std::optional name, std::optional size) 
{ @@ -64,8 +65,7 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, CTL_INF("path_to_model_yaml: " << rel.string()); auto author_id = author.has_value() ? author.value() : "cortexso"; - cortex::db::Models modellist_utils_obj; - if (!modellist_utils_obj.HasModel(ggufDownloadItem.id)) { + if (!db_service.HasModel(ggufDownloadItem.id)) { cortex::db::ModelEntry model_entry{ .model = ggufDownloadItem.id, .author_repo_id = author_id, @@ -73,18 +73,17 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, .path_to_model_yaml = rel.string(), .model_alias = ggufDownloadItem.id, .status = cortex::db::ModelStatus::Downloaded}; - auto result = modellist_utils_obj.AddModelEntry(model_entry); + auto result = db_service.AddModelEntry(model_entry); if (result.has_error()) { CTL_ERR("Error adding model to modellist: " + result.error()); } } else { - if (auto m = modellist_utils_obj.GetModelInfo(ggufDownloadItem.id); + if (auto m = db_service.GetModelInfo(ggufDownloadItem.id); m.has_value()) { auto upd_m = m.value(); upd_m.status = cortex::db::ModelStatus::Downloaded; - if (auto r = - modellist_utils_obj.UpdateModelEntry(ggufDownloadItem.id, upd_m); + if (auto r = db_service.UpdateModelEntry(ggufDownloadItem.id, upd_m); r.has_error()) { CTL_ERR(r.error()); } @@ -137,10 +136,9 @@ cpp::result GetDownloadTask( void ModelService::ForceIndexingModelList() { CTL_INF("Force indexing model list"); - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; - auto list_entry = modellist_handler.LoadModelList(); + auto list_entry = db_service_->LoadModelList(); if (list_entry.has_error()) { CTL_ERR("Failed to load model list: " << list_entry.error()); return; @@ -164,8 +162,7 @@ void ModelService::ForceIndexingModelList() { yaml_handler.Reset(); } catch (const std::exception& e) { // remove in db - auto remove_result = - modellist_handler.DeleteModelEntry(model_entry.model); + auto remove_result = db_service_->DeleteModelEntry(model_entry.model); // silently ignore result 
} } @@ -218,10 +215,8 @@ cpp::result ModelService::HandleCortexsoModel( auto default_model_branch = huggingface_utils::GetDefaultBranch(modelName); - cortex::db::Models modellist_handler; - auto downloaded_model_ids = - modellist_handler.FindRelatedModel(modelName).value_or( - std::vector{}); + auto downloaded_model_ids = db_service_->FindRelatedModel(modelName).value_or( + std::vector{}); std::vector avai_download_opts{}; for (const auto& branch : branches.value()) { @@ -261,9 +256,8 @@ cpp::result ModelService::HandleCortexsoModel( std::optional ModelService::GetDownloadedModel( const std::string& modelId) const { - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; - auto model_entry = modellist_handler.GetModelInfo(modelId); + auto model_entry = db_service_->GetModelInfo(modelId); if (!model_entry.has_value()) { return std::nullopt; } @@ -310,7 +304,6 @@ cpp::result ModelService::HandleDownloadUrlAsync( } std::string huggingFaceHost{kHuggingFaceHost}; - cortex::db::Models modellist_handler; std::string unique_model_id = ""; if (temp_model_id.has_value()) { unique_model_id = temp_model_id.value(); @@ -318,7 +311,7 @@ cpp::result ModelService::HandleDownloadUrlAsync( unique_model_id = author + ":" + model_id + ":" + file_name; } - auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + auto model_entry = db_service_->GetModelInfo(unique_model_id); if (model_entry.has_value() && model_entry->status == cortex::db::ModelStatus::Downloaded) { CLI_LOG("Model already downloaded: " << unique_model_id); @@ -346,14 +339,15 @@ cpp::result ModelService::HandleDownloadUrlAsync( .localPath = local_path, }}}}; - auto on_finished = [author, temp_name](const DownloadTask& finishedTask) { + auto on_finished = [this, author, + temp_name](const DownloadTask& finishedTask) { // Sum downloadedBytes from all items uint64_t model_size = 0; for (const auto& item : finishedTask.items) { model_size = model_size + item.bytes.value_or(0); } auto 
gguf_download_item = finishedTask.items[0]; - ParseGguf(gguf_download_item, author, temp_name, model_size); + ParseGguf(*db_service_, gguf_download_item, author, temp_name, model_size); }; downloadTask.id = unique_model_id; @@ -366,11 +360,10 @@ ModelService::GetEstimation(const std::string& model_handle, int n_ubatch) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); return cpp::fail(model_entry.error()); @@ -384,8 +377,8 @@ ModelService::GetEstimation(const std::string& model_handle, fs::path(model_entry.value().path_to_model_yaml)) .string()); auto mc = yaml_handler.GetModelConfig(); - services::HardwareService hw_svc; - auto hw_info = hw_svc.GetHardwareInfo(); + assert(hw_service_); + auto hw_info = hw_service_->GetHardwareInfo(); auto free_vram_MiB = 0u; for (const auto& gpu : hw_info.gpus) { free_vram_MiB += gpu.free_vram; @@ -438,8 +431,7 @@ cpp::result ModelService::HandleUrl( std::string huggingFaceHost{kHuggingFaceHost}; std::string unique_model_id{author + ":" + model_id + ":" + file_name}; - cortex::db::Models modellist_handler; - auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + auto model_entry = db_service_->GetModelInfo(unique_model_id); if (model_entry.has_value()) { CLI_LOG("Model already downloaded: " << unique_model_id); @@ -467,14 +459,14 @@ cpp::result ModelService::HandleUrl( .localPath = local_path, }}}}; - auto on_finished = [author](const DownloadTask& finishedTask) { + auto on_finished = [this, author](const DownloadTask& finishedTask) { // Sum downloadedBytes from all items uint64_t model_size = 0; for (const auto& item : finishedTask.items) { model_size = model_size + item.bytes.value_or(0); } auto gguf_download_item = 
finishedTask.items[0]; - ParseGguf(gguf_download_item, author, std::nullopt, model_size); + ParseGguf(*db_service_, gguf_download_item, author, std::nullopt, model_size); }; auto result = download_service_->AddDownloadTask(downloadTask, on_finished); @@ -488,7 +480,7 @@ cpp::result ModelService::HandleUrl( } bool ModelService::HasModel(const std::string& id) const { - return cortex::db::Models().HasModel(id); + return db_service_->HasModel(id); } cpp::result @@ -501,7 +493,6 @@ ModelService::DownloadModelFromCortexsoAsync( return cpp::fail(download_task.error()); } - cortex::db::Models modellist_handler; std::string unique_model_id = ""; if (temp_model_id.has_value()) { unique_model_id = temp_model_id.value(); @@ -509,13 +500,13 @@ ModelService::DownloadModelFromCortexsoAsync( unique_model_id = name + ":" + branch; } - auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + auto model_entry = db_service_->GetModelInfo(unique_model_id); if (model_entry.has_value() && model_entry->status == cortex::db::ModelStatus::Downloaded) { return cpp::fail("Please delete the model before downloading again"); } - auto on_finished = [unique_model_id, + auto on_finished = [this, unique_model_id, branch](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; @@ -551,8 +542,7 @@ ModelService::DownloadModelFromCortexsoAsync( file_manager_utils::ToRelativeCortexDataPath(model_yml_item->localPath); CTL_INF("path_to_model_yaml: " << rel.string()); - cortex::db::Models modellist_utils_obj; - if (!modellist_utils_obj.HasModel(unique_model_id)) { + if (!db_service_->HasModel(unique_model_id)) { cortex::db::ModelEntry model_entry{ .model = unique_model_id, .author_repo_id = "cortexso", @@ -560,18 +550,16 @@ ModelService::DownloadModelFromCortexsoAsync( .path_to_model_yaml = rel.string(), .model_alias = unique_model_id, .status = cortex::db::ModelStatus::Downloaded}; - auto result = 
modellist_utils_obj.AddModelEntry(model_entry); + auto result = db_service_->AddModelEntry(model_entry); if (result.has_error()) { CTL_ERR("Error adding model to modellist: " + result.error()); } } else { - if (auto m = modellist_utils_obj.GetModelInfo(unique_model_id); - m.has_value()) { + if (auto m = db_service_->GetModelInfo(unique_model_id); m.has_value()) { auto upd_m = m.value(); upd_m.status = cortex::db::ModelStatus::Downloaded; - if (auto r = - modellist_utils_obj.UpdateModelEntry(unique_model_id, upd_m); + if (auto r = db_service_->UpdateModelEntry(unique_model_id, upd_m); r.has_error()) { CTL_ERR(r.error()); } @@ -595,7 +583,7 @@ cpp::result ModelService::DownloadModelFromCortexso( } std::string model_id{name + ":" + branch}; - auto on_finished = [branch, model_id](const DownloadTask& finishedTask) { + auto on_finished = [this, branch, model_id](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; @@ -622,8 +610,7 @@ cpp::result ModelService::DownloadModelFromCortexso( file_manager_utils::ToRelativeCortexDataPath(model_yml_item->localPath); CTL_INF("path_to_model_yaml: " << rel.string()); - cortex::db::Models modellist_utils_obj; - if (!modellist_utils_obj.HasModel(model_id)) { + if (!db_service_->HasModel(model_id)) { cortex::db::ModelEntry model_entry{ .model = model_id, .author_repo_id = "cortexso", @@ -631,16 +618,16 @@ cpp::result ModelService::DownloadModelFromCortexso( .path_to_model_yaml = rel.string(), .model_alias = model_id, .status = cortex::db::ModelStatus::Downloaded}; - auto result = modellist_utils_obj.AddModelEntry(model_entry); + auto result = db_service_->AddModelEntry(model_entry); if (result.has_error()) { CTL_ERR("Error adding model to modellist: " + result.error()); } } else { - if (auto m = modellist_utils_obj.GetModelInfo(model_id); m.has_value()) { + if (auto m = db_service_->GetModelInfo(model_id); m.has_value()) { auto upd_m = m.value(); upd_m.status = 
cortex::db::ModelStatus::Downloaded; - if (auto r = modellist_utils_obj.UpdateModelEntry(model_id, upd_m); + if (auto r = db_service_->UpdateModelEntry(model_id, upd_m); r.has_error()) { CTL_ERR(r.error()); } @@ -694,7 +681,6 @@ cpp::result ModelService::DeleteModel( const std::string& model_handle) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; auto result = StopModel(model_handle); @@ -706,7 +692,7 @@ cpp::result ModelService::DeleteModel( } try { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); return cpp::fail(model_entry.error()); @@ -737,7 +723,7 @@ cpp::result ModelService::DeleteModel( } // update model.list - if (modellist_handler.DeleteModelEntry(model_handle)) { + if (db_service_->DeleteModelEntry(model_handle)) { return {}; } else { return cpp::fail("Could not delete model: " + model_handle); @@ -753,7 +739,6 @@ cpp::result ModelService::StartModel( bool bypass_model_check) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; std::optional custom_prompt_template; std::optional ctx_len; @@ -771,7 +756,7 @@ cpp::result ModelService::StartModel( Json::Value json_data; // Currently we don't support download vision models, so we need to bypass check if (!bypass_model_check) { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); return cpp::fail(model_entry.error()); @@ -910,7 +895,6 @@ cpp::result ModelService::StopModel( const std::string& model_handle) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; 
config::YamlHandler yaml_handler; try { @@ -918,7 +902,7 @@ cpp::result ModelService::StopModel( bypass_stop_check_set_.end()); std::string engine_name = ""; if (!bypass_check) { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); return cpp::fail(model_entry.error()); @@ -958,11 +942,10 @@ cpp::result ModelService::GetModelStatus( const std::string& model_handle) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); return cpp::fail(model_entry.error()); @@ -1083,8 +1066,7 @@ cpp::result ModelService::GetModelPullInfo( auto default_model_branch = huggingface_utils::GetDefaultBranch(model_name); - cortex::db::Models modellist_handler; - auto downloaded_model_ids = modellist_handler.FindRelatedModel(model_name) + auto downloaded_model_ids = db_service_->FindRelatedModel(model_name) .value_or(std::vector{}); std::vector avai_download_opts{}; @@ -1128,8 +1110,8 @@ cpp::result, std::string> ModelService::MayFallbackToCpu(const std::string& model_path, int ngl, int ctx_len, int n_batch, int n_ubatch, const std::string& kv_cache_type) { - services::HardwareService hw_svc; - auto hw_info = hw_svc.GetHardwareInfo(); + assert(hw_service_); + auto hw_info = hw_service_->GetHardwareInfo(); assert(!!engine_svc_); auto default_engine = engine_svc_->GetDefaultEngineVariant(kLlamaEngine); bool is_cuda = false; diff --git a/engine/services/model_service.h b/engine/services/model_service.h index ab3596812..cc659fea5 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -6,12 +6,12 @@ #include 
"common/engine_servicei.h" #include "common/model_metadata.h" #include "config/model_config.h" +#include "services/database_service.h" #include "services/download_service.h" +#include "services/hardware_service.h" #include "utils/hardware/gguf/gguf_file_estimate.h" -namespace services { class InferenceService; -} struct ModelPullInfo { std::string id; @@ -31,14 +31,14 @@ class ModelService { public: void ForceIndexingModelList(); - explicit ModelService(std::shared_ptr download_service) - : download_service_{download_service} {}; - - explicit ModelService( - std::shared_ptr download_service, - std::shared_ptr inference_service, - std::shared_ptr engine_svc) - : download_service_{download_service}, + explicit ModelService(std::shared_ptr db_service, + std::shared_ptr hw_service, + std::shared_ptr download_service, + std::shared_ptr inference_service, + std::shared_ptr engine_svc) + : db_service_(db_service), + hw_service_(hw_service), + download_service_{download_service}, inference_svc_(inference_service), engine_svc_(engine_svc) {}; @@ -115,8 +115,10 @@ class ModelService { const std::string& model_path, int ngl, int ctx_len, int n_batch = 2048, int n_ubatch = 2048, const std::string& kv_cache_type = "f16"); + std::shared_ptr db_service_; + std::shared_ptr hw_service_; std::shared_ptr download_service_; - std::shared_ptr inference_svc_; + std::shared_ptr inference_svc_; std::unordered_set bypass_stop_check_set_; std::shared_ptr engine_svc_ = nullptr; diff --git a/engine/services/model_source_service.cc b/engine/services/model_source_service.cc index a7d9d5e6e..7fc0ef5b2 100644 --- a/engine/services/model_source_service.cc +++ b/engine/services/model_source_service.cc @@ -9,7 +9,6 @@ #include "utils/string_utils.h" #include "utils/url_parser.h" -namespace services { namespace hu = huggingface_utils; namespace { @@ -61,10 +60,13 @@ std::vector ParseJsonString(const std::string& json_str) { } // namespace -ModelSourceService::ModelSourceService() { 
+ModelSourceService::ModelSourceService( + std::shared_ptr db_service) + : db_service_(db_service) { sync_db_thread_ = std::thread(&ModelSourceService::SyncModelSource, this); running_ = true; } + ModelSourceService::~ModelSourceService() { running_ = false; if (sync_db_thread_.joinable()) { @@ -106,8 +108,7 @@ cpp::result ModelSourceService::AddModelSource( cpp::result ModelSourceService::RemoveModelSource( const std::string& model_source) { - cortex::db::Models model_db; - auto srcs = model_db.GetModelSources(); + auto srcs = db_service_->GetModelSources(); if (srcs.has_error()) { return cpp::fail(srcs.error()); } else { @@ -127,13 +128,13 @@ cpp::result ModelSourceService::RemoveModelSource( } if (r.pathParams.size() == 1) { - if (auto del_res = model_db.DeleteModelEntryWithOrg(model_source); + if (auto del_res = db_service_->DeleteModelEntryWithOrg(model_source); del_res.has_error()) { CTL_INF(del_res.error()); return cpp::fail(del_res.error()); } } else { - if (auto del_res = model_db.DeleteModelEntryWithRepo(model_source); + if (auto del_res = db_service_->DeleteModelEntryWithRepo(model_source); del_res.has_error()) { CTL_INF(del_res.error()); return cpp::fail(del_res.error()); @@ -145,8 +146,7 @@ cpp::result ModelSourceService::RemoveModelSource( cpp::result, std::string> ModelSourceService::GetModelSources() { - cortex::db::Models model_db; - return model_db.GetModelSources(); + return db_service_->GetModelSources(); } cpp::result ModelSourceService::AddHfOrg( @@ -156,10 +156,9 @@ cpp::result ModelSourceService::AddHfOrg( if (res.has_value()) { auto models = ParseJsonString(res.value()); // Get models from db - cortex::db::Models model_db; - auto model_list_before = - model_db.GetModels(model_source).value_or(std::vector{}); + auto model_list_before = db_service_->GetModels(model_source) + .value_or(std::vector{}); std::unordered_set updated_model_list; // Add new models for (auto const& m : models) { @@ -179,7 +178,7 @@ cpp::result 
ModelSourceService::AddHfOrg( // Clean up for (auto const& mid : model_list_before) { if (updated_model_list.find(mid) == updated_model_list.end()) { - if (auto del_res = model_db.DeleteModelEntry(mid); + if (auto del_res = db_service_->DeleteModelEntry(mid); del_res.has_error()) { CTL_INF(del_res.error()); } @@ -195,10 +194,9 @@ cpp::result ModelSourceService::AddHfRepo( const std::string& model_source, const std::string& author, const std::string& model_name) { // Get models from db - cortex::db::Models model_db; auto model_list_before = - model_db.GetModels(model_source).value_or(std::vector{}); + db_service_->GetModels(model_source).value_or(std::vector{}); std::unordered_set updated_model_list; auto add_res = AddRepoSiblings(model_source, author, model_name); if (add_res.has_error()) { @@ -208,7 +206,8 @@ cpp::result ModelSourceService::AddHfRepo( } for (auto const& mid : model_list_before) { if (updated_model_list.find(mid) == updated_model_list.end()) { - if (auto del_res = model_db.DeleteModelEntry(mid); del_res.has_error()) { + if (auto del_res = db_service_->DeleteModelEntry(mid); + del_res.has_error()) { CTL_INF(del_res.error()); } } @@ -234,7 +233,6 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source, for (const auto& sibling : repo_info->siblings) { if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { - cortex::db::Models model_db; std::string model_id = author + ":" + model_name + ":" + sibling.rfilename; cortex::db::ModelEntry e = { @@ -248,15 +246,15 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source, .status = cortex::db::ModelStatus::Downloadable, .engine = "llama-cpp", .metadata = repo_info->metadata}; - if (!model_db.HasModel(model_id)) { - if (auto add_res = model_db.AddModelEntry(e); add_res.has_error()) { + if (!db_service_->HasModel(model_id)) { + if (auto add_res = db_service_->AddModelEntry(e); add_res.has_error()) { CTL_INF(add_res.error()); } } else { - if (auto m = 
model_db.GetModelInfo(model_id); + if (auto m = db_service_->GetModelInfo(model_id); m.has_value() && m->status == cortex::db::ModelStatus::Downloadable) { - if (auto upd_res = model_db.UpdateModelEntry(model_id, e); + if (auto upd_res = db_service_->UpdateModelEntry(model_id, e); upd_res.has_error()) { CTL_INF(upd_res.error()); } @@ -276,10 +274,9 @@ cpp::result ModelSourceService::AddCortexsoOrg( if (res.has_value()) { auto models = ParseJsonString(res.value()); // Get models from db - cortex::db::Models model_db; - auto model_list_before = - model_db.GetModels(model_source).value_or(std::vector{}); + auto model_list_before = db_service_->GetModels(model_source) + .value_or(std::vector{}); std::unordered_set updated_model_list; for (auto const& m : models) { CTL_INF(m.id); @@ -313,7 +310,7 @@ cpp::result ModelSourceService::AddCortexsoOrg( // Clean up for (auto const& mid : model_list_before) { if (updated_model_list.find(mid) == updated_model_list.end()) { - if (auto del_res = model_db.DeleteModelEntry(mid); + if (auto del_res = db_service_->DeleteModelEntry(mid); del_res.has_error()) { CTL_INF(del_res.error()); } @@ -340,10 +337,9 @@ cpp::result ModelSourceService::AddCortexsoRepo( return cpp::fail(repo_info.error()); } // Get models from db - cortex::db::Models model_db; auto model_list_before = - model_db.GetModels(model_source).value_or(std::vector{}); + db_service_->GetModels(model_source).value_or(std::vector{}); std::unordered_set updated_model_list; for (auto const& [branch, _] : branches.value()) { @@ -359,7 +355,8 @@ cpp::result ModelSourceService::AddCortexsoRepo( // Clean up for (auto const& mid : model_list_before) { if (updated_model_list.find(mid) == updated_model_list.end()) { - if (auto del_res = model_db.DeleteModelEntry(mid); del_res.has_error()) { + if (auto del_res = db_service_->DeleteModelEntry(mid); + del_res.has_error()) { CTL_INF(del_res.error()); } } @@ -397,7 +394,6 @@ ModelSourceService::AddCortexsoRepoBranch(const std::string& 
model_source, CTL_INF("Only support gguf file format! - branch: " << branch); return {}; } else { - cortex::db::Models model_db; std::string model_id = model_name + ":" + branch; cortex::db::ModelEntry e = {.model = model_id, .author_repo_id = author, @@ -409,16 +405,16 @@ ModelSourceService::AddCortexsoRepoBranch(const std::string& model_source, .status = cortex::db::ModelStatus::Downloadable, .engine = "llama-cpp", .metadata = metadata}; - if (!model_db.HasModel(model_id)) { + if (!db_service_->HasModel(model_id)) { CTL_INF("Adding model to db: " << model_name << ":" << branch); - if (auto res = model_db.AddModelEntry(e); + if (auto res = db_service_->AddModelEntry(e); res.has_error() || !res.value()) { CTL_DBG("Cannot add model to db: " << model_id); } } else { - if (auto m = model_db.GetModelInfo(model_id); + if (auto m = db_service_->GetModelInfo(model_id); m.has_value() && m->status == cortex::db::ModelStatus::Downloadable) { - if (auto upd_res = model_db.UpdateModelEntry(model_id, e); + if (auto upd_res = db_service_->UpdateModelEntry(model_id, e); upd_res.has_error()) { CTL_INF(upd_res.error()); } @@ -444,8 +440,7 @@ void ModelSourceService::SyncModelSource() { CTL_DBG("Start to sync cortex.db"); start_time = current_time; - cortex::db::Models model_db; - auto res = model_db.GetModelSources(); + auto res = db_service_->GetModelSources(); if (res.has_error()) { CTL_INF(res.error()); } else { @@ -489,5 +484,3 @@ void ModelSourceService::SyncModelSource() { } } } - -} // namespace services \ No newline at end of file diff --git a/engine/services/model_source_service.h b/engine/services/model_source_service.h index aa0b37259..7227267d3 100644 --- a/engine/services/model_source_service.h +++ b/engine/services/model_source_service.h @@ -2,14 +2,14 @@ #include #include #include +#include "services/database_service.h" #include "utils/result.hpp" -namespace services { class ModelSourceService { public: - explicit ModelSourceService(); + explicit 
ModelSourceService(std::shared_ptr db_service); ~ModelSourceService(); - + cpp::result AddModelSource( const std::string& model_source); @@ -22,9 +22,9 @@ class ModelSourceService { cpp::result AddHfOrg(const std::string& model_source, const std::string& author); - cpp::result AddHfRepo( - const std::string& model_source, const std::string& author, - const std::string& model_name); + cpp::result AddHfRepo(const std::string& model_source, + const std::string& author, + const std::string& model_name); cpp::result, std::string> AddRepoSiblings( const std::string& model_source, const std::string& author, @@ -41,13 +41,12 @@ class ModelSourceService { AddCortexsoRepoBranch(const std::string& model_source, const std::string& author, const std::string& model_name, - const std::string& branch, - const std::string& metadata); + const std::string& branch, const std::string& metadata); void SyncModelSource(); private: + std::shared_ptr db_service_ = nullptr; std::thread sync_db_thread_; std::atomic running_; -}; -} // namespace services \ No newline at end of file +}; \ No newline at end of file From 22ff0a10ee3132f39fad1644a9ed3edb2c604691 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 <35255081+nguyenhoangthuan99@users.noreply.github.com> Date: Tue, 31 Dec 2024 09:28:29 +0700 Subject: [PATCH 08/16] Feat/python engine (#1784) * chore: add document * feat: update engine interface * chore: add document * feat: update engine interface * Feat: init python engine * Fix: conflict * feat: add python engine implementation * Fix: CI build window * Fix: CI build window * feat: support download python model from cortexso * feat: add inference interface * feat: integrate to cortex cpp * fix: remove pythone engine load engine option * Feat: init environment interface * feat: move virtual environment inside model * Update CMakeLists.txt * Update CMakeLists.txt * fix: CI build * fix: move log of python to cortex logs folder * fix: unitest for remote engine because change location of 
template renderer * fix: CI build windows * fix: CI build windows * feat: add depends model.yml for python engine * fix: CI build * update set permission api * Fix: comment * Fix: remove unnecessary interface * Fix comment * Fix: comment review --------- Co-authored-by: James --- engine/CMakeLists.txt | 7 +- engine/cli/CMakeLists.txt | 5 +- engine/common/base.h | 7 +- engine/common/download_task.h | 13 +- engine/config/model_config.h | 341 ++++++- engine/controllers/models.cc | 30 + engine/controllers/server.cc | 50 + engine/controllers/server.h | 7 + engine/cortex-common/EngineI.h | 8 + .../extensions/python-engine/python_engine.cc | 860 ++++++++++++++++++ .../extensions/python-engine/python_engine.h | 166 ++++ .../extensions/remote-engine/remote_engine.h | 6 +- .../{remote-engine => }/template_renderer.cc | 2 +- .../{remote-engine => }/template_renderer.h | 2 +- engine/services/engine_service.cc | 19 + engine/services/engine_service.h | 1 + engine/services/inference_service.cc | 58 ++ engine/services/inference_service.h | 7 + engine/services/model_service.cc | 192 +++- engine/test/components/CMakeLists.txt | 2 +- engine/test/components/test_remote_engine.cc | 6 +- engine/utils/config_yaml_utils.h | 10 +- engine/utils/curl_utils.cc | 35 + engine/utils/curl_utils.h | 2 + engine/utils/engine_constants.h | 5 + engine/utils/file_manager_utils.cc | 1 + engine/utils/jinja_utils.h | 4 +- engine/utils/set_permission_utils.h | 76 ++ 28 files changed, 1882 insertions(+), 40 deletions(-) create mode 100644 engine/extensions/python-engine/python_engine.cc create mode 100644 engine/extensions/python-engine/python_engine.h rename engine/extensions/{remote-engine => }/template_renderer.cc (99%) rename engine/extensions/{remote-engine => }/template_renderer.h (97%) create mode 100644 engine/utils/set_permission_utils.h diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index e82e07aab..024f015a8 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -142,9 
+142,14 @@ file(APPEND "${CMAKE_CURRENT_BINARY_DIR}/cortex_openapi.h" add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc + + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/template_renderer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/python-engine/python_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils/dylib_path_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/template_renderer.cc + ) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index eb29460a7..4ca734d6a 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -85,7 +85,10 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/database_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc + + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/python-engine/python_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/template_renderer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc ${CMAKE_CURRENT_SOURCE_DIR}/../utils/config_yaml_utils.cc diff --git a/engine/common/base.h b/engine/common/base.h index 478cc7feb..b5de09059 100644 --- a/engine/common/base.h +++ b/engine/common/base.h @@ -5,7 +5,7 @@ using namespace drogon; class BaseModel { public: - virtual ~BaseModel() {} + virtual ~BaseModel() = default; // Model management virtual void LoadModel( @@ -27,7 +27,7 @@ class BaseModel { class BaseChatCompletion { public: - virtual ~BaseChatCompletion() {} + virtual ~BaseChatCompletion() = default; // General chat method virtual void ChatCompletion( @@ -37,7 +37,7 @@ class BaseChatCompletion { 
class BaseEmbedding { public: - virtual ~BaseEmbedding() {} + virtual ~BaseEmbedding() = default; // Implement embedding functionality specific to chat virtual void Embedding( @@ -46,3 +46,4 @@ class BaseEmbedding { // The derived class can also override other methods if needed }; + diff --git a/engine/common/download_task.h b/engine/common/download_task.h index 95e736394..53f1902c5 100644 --- a/engine/common/download_task.h +++ b/engine/common/download_task.h @@ -6,7 +6,14 @@ #include #include -enum class DownloadType { Model, Engine, Miscellaneous, CudaToolkit, Cortex }; +enum class DownloadType { + Model, + Engine, + Miscellaneous, + CudaToolkit, + Cortex, + Environments +}; struct DownloadItem { @@ -48,6 +55,8 @@ inline std::string DownloadTypeToString(DownloadType type) { return "CudaToolkit"; case DownloadType::Cortex: return "Cortex"; + case DownloadType::Environments: + return "Environments"; default: return "Unknown"; } @@ -64,6 +73,8 @@ inline DownloadType DownloadTypeFromString(const std::string& str) { return DownloadType::CudaToolkit; } else if (str == "Cortex") { return DownloadType::Cortex; + } else if (str == "Environments") { + return DownloadType::Environments; } else { return DownloadType::Miscellaneous; } diff --git a/engine/config/model_config.h b/engine/config/model_config.h index a799adb27..d8ede92f7 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -1,8 +1,11 @@ #pragma once #include -#include +#include #include +#include + +#include #include #include #include @@ -482,4 +485,340 @@ struct ModelConfig { } }; +struct Endpoint { + std::string method; + std::string path; + std::string transform_request; + std::string transform_response; +}; + +struct PythonModelConfig { + // General Metadata + std::string id; + std::string model; + std::string name; + int version; + + // Inference Parameters + Endpoint load_model; + Endpoint destroy; + Endpoint inference; + Endpoint heath_check; + std::vector extra_endpoints; + + // 
Model Load Parameters + std::string port; + std::string script; + std::string log_path; + std::string log_level; + std::string environment; + std::vector command; // New command field + std::vector files; + std::vector depends; + std::string engine; + Json::Value extra_params; // Accept dynamic extra parameters + + // Method to convert C++ struct to YAML + void ToYaml(const std::string& filepath) const { + YAML::Emitter out; + out << YAML::BeginMap; + + out << YAML::Key << "id" << YAML::Value << id; + out << YAML::Key << "model" << YAML::Value << model; + out << YAML::Key << "name" << YAML::Value << name; + out << YAML::Key << "version" << YAML::Value << version; + + // Inference Parameters + out << YAML::Key << "load_model" << YAML::Value << YAML::BeginMap; + out << YAML::Key << "method" << YAML::Value << load_model.method; + out << YAML::Key << "path" << YAML::Value << load_model.path; + out << YAML::Key << "transform_request" << YAML::Value + << load_model.transform_request; + out << YAML::Key << "transform_response" << YAML::Value + << load_model.transform_response; + out << YAML::EndMap; + + out << YAML::Key << "destroy" << YAML::Value << YAML::BeginMap; + out << YAML::Key << "method" << YAML::Value << destroy.method; + out << YAML::Key << "path" << YAML::Value << destroy.path; + out << YAML::EndMap; + + out << YAML::Key << "inference" << YAML::Value << YAML::BeginMap; + out << YAML::Key << "method" << YAML::Value << inference.method; + out << YAML::Key << "path" << YAML::Value << inference.path; + out << YAML::EndMap; + + out << YAML::Key << "extra_endpoints" << YAML::Value << YAML::BeginSeq; + for (const auto& endpoint : extra_endpoints) { + out << YAML::BeginMap; + out << YAML::Key << "method" << YAML::Value << endpoint.method; + out << YAML::Key << "path" << YAML::Value << endpoint.path; + out << YAML::EndMap; + } + out << YAML::EndSeq; + + // Model Load Parameters + out << YAML::Key << "port" << YAML::Value << port; + out << YAML::Key << "script" << 
YAML::Value << script; + out << YAML::Key << "log_path" << YAML::Value << log_path; + out << YAML::Key << "log_level" << YAML::Value << log_level; + out << YAML::Key << "environment" << YAML::Value << environment; + + // Serialize command as YAML list + out << YAML::Key << "command" << YAML::Value << YAML::BeginSeq; + for (const auto& cmd : command) { + out << cmd; + } + out << YAML::EndSeq; + + // Serialize files as YAML list + out << YAML::Key << "files" << YAML::Value << YAML::BeginSeq; + for (const auto& file : files) { + out << file; + } + out << YAML::EndSeq; + + // Serialize command as YAML list + out << YAML::Key << "depends" << YAML::Value << YAML::BeginSeq; + for (const auto& depend : depends) { + out << depend; + } + out << YAML::EndSeq; + + out << YAML::Key << "engine" << YAML::Value << engine; + + // Serialize extra_params as YAML + out << YAML::Key << "extra_params" << YAML::Value << YAML::BeginMap; + for (Json::ValueConstIterator iter = extra_params.begin(); + iter != extra_params.end(); ++iter) { + out << YAML::Key << iter.key().asString() << YAML::Value + << iter->asString(); + } + out << YAML::EndMap; + + std::ofstream fout(filepath); + if (!fout.is_open()) { + throw std::runtime_error("Failed to open file for writing: " + filepath); + } + fout << out.c_str(); + } + + // Method to populate struct from YAML file + void ReadFromYaml(const std::string& filePath) { + YAML::Node config = YAML::LoadFile(filePath); + + if (config["id"]) + id = config["id"].as(); + if (config["model"]) + model = config["model"].as(); + if (config["name"]) + name = config["name"].as(); + if (config["version"]) + version = config["version"].as(); + + // Inference Parameters + + auto ip = config; + if (ip["load_model"]) { + load_model.method = ip["load_model"]["method"].as(); + load_model.path = ip["load_model"]["path"].as(); + load_model.transform_request = + ip["load_model"]["transform_request"].as(); + load_model.transform_response = + 
ip["load_model"]["transform_response"].as(); + } + if (ip["destroy"]) { + destroy.method = ip["destroy"]["method"].as(); + destroy.path = ip["destroy"]["path"].as(); + } + if (ip["inference"]) { + inference.method = ip["inference"]["method"].as(); + inference.path = ip["inference"]["path"].as(); + } + if (ip["extra_endpoints"] && ip["extra_endpoints"].IsSequence()) { + for (const auto& endpoint : ip["extra_endpoints"]) { + Endpoint e; + e.method = endpoint["method"].as(); + e.path = endpoint["path"].as(); + extra_endpoints.push_back(e); + } + } + + // Model Load Parameters + + auto mlp = config; + if (mlp["port"]) + port = mlp["port"].as(); + if (mlp["script"]) + script = mlp["script"].as(); + if (mlp["log_path"]) + log_path = mlp["log_path"].as(); + if (mlp["log_level"]) + log_level = mlp["log_level"].as(); + if (mlp["environment"]) + environment = mlp["environment"].as(); + if (mlp["engine"]) + engine = mlp["engine"].as(); + + if (mlp["command"] && mlp["command"].IsSequence()) { + for (const auto& cmd : mlp["command"]) { + command.push_back(cmd.as()); + } + } + + if (mlp["files"] && mlp["files"].IsSequence()) { + for (const auto& file : mlp["files"]) { + files.push_back(file.as()); + } + } + + if (mlp["depends"] && mlp["depends"].IsSequence()) { + for (const auto& depend : mlp["depends"]) { + depends.push_back(depend.as()); + } + } + + if (mlp["extra_params"]) { + for (YAML::const_iterator it = mlp["extra_params"].begin(); + it != mlp["extra_params"].end(); ++it) { + extra_params[it->first.as()] = + it->second.as(); + } + } + } + + // Method to convert the struct to JSON + Json::Value ToJson() const { + Json::Value root; + + root["id"] = id; + root["model"] = model; + root["name"] = name; + root["version"] = version; + + // Inference Parameters + root["load_model"]["method"] = load_model.method; + root["load_model"]["path"] = load_model.path; + root["load_model"]["transform_request"] = load_model.transform_request; + root["load_model"]["transform_response"] = 
load_model.transform_response; + + root["destroy"]["method"] = destroy.method; + root["destroy"]["path"] = destroy.path; + + root["inference"]["method"] = inference.method; + root["inference"]["path"] = inference.path; + + for (const auto& endpoint : extra_endpoints) { + Json::Value e; + e["method"] = endpoint.method; + e["path"] = endpoint.path; + root["extra_endpoints"].append(e); + } + + // Model Load Parameters + root["port"] = port; + root["log_path"] = log_path; + root["log_level"] = log_level; + root["environment"] = environment; + root["script"] = script; + + // Serialize command as JSON array + for (const auto& cmd : command) { + root["command"].append(cmd); + } + + for (const auto& file : files) { + root["files"].append(file); + } + + for (const auto& depend : depends) { + root["depends"].append(depend); + } + + root["engine"] = engine; + root["extra_params"] = extra_params; // Serialize the JSON value directly + + return root; + } + + // Method to populate struct from JSON + void FromJson(const Json::Value& root) { + + if (root.isMember("id")) + id = root["id"].asString(); + if (root.isMember("model")) + model = root["model"].asString(); + if (root.isMember("name")) + name = root["name"].asString(); + if (root.isMember("version")) + version = root["version"].asInt(); + + // Inference Parameters + + const Json::Value& ip = root; + if (ip.isMember("load_model")) { + load_model.method = ip["load_model"]["method"].asString(); + load_model.path = ip["load_model"]["path"].asString(); + load_model.transform_request = + ip["load_model"]["transform_request"].asString(); + load_model.transform_response = + ip["load_model"]["transform_response"].asString(); + } + if (ip.isMember("destroy")) { + destroy.method = ip["destroy"]["method"].asString(); + destroy.path = ip["destroy"]["path"].asString(); + } + if (ip.isMember("inference")) { + inference.method = ip["inference"]["method"].asString(); + inference.path = ip["inference"]["path"].asString(); + } + if 
(ip.isMember("extra_endpoints")) { + for (const auto& endpoint : ip["extra_endpoints"]) { + Endpoint e; + e.method = endpoint["method"].asString(); + e.path = endpoint["path"].asString(); + extra_endpoints.push_back(e); + } + } + + // Model Load Parameters + + const Json::Value& mlp = root; + if (mlp.isMember("port")) + port = mlp["port"].asString(); + if (mlp.isMember("log_path")) + log_path = mlp["log_path"].asString(); + if (mlp.isMember("log_level")) + log_level = mlp["log_level"].asString(); + if (mlp.isMember("environment")) + environment = mlp["environment"].asString(); + if (mlp.isMember("engine")) + engine = mlp["engine"].asString(); + if (mlp.isMember("script")) + script = mlp["script"].asString(); + + if (mlp.isMember("command")) { + for (const auto& cmd : mlp["command"]) { + command.push_back(cmd.asString()); + } + } + + if (mlp.isMember("files")) { + for (const auto& file : mlp["files"]) { + files.push_back(file.asString()); + } + } + + if (mlp.isMember("depends")) { + for (const auto& depend : mlp["depends"]) { + depends.push_back(depend.asString()); + } + } + + if (mlp.isMember("extra_params")) { + extra_params = mlp["extra_params"]; // Directly assign the JSON value + } + } +}; + } // namespace config diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 1a501287d..34c6504ac 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -210,6 +210,16 @@ void Models::ListModel( } data.append(std::move(obj)); yaml_handler.Reset(); + } else if (model_config.engine == kPythonEngine) { + config::PythonModelConfig python_model_config; + python_model_config.ReadFromYaml( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.path_to_model_yaml)) + .string()); + Json::Value obj = python_model_config.ToJson(); + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + data.append(std::move(obj)); } else { config::RemoteModelConfig remote_model_config; remote_model_config.LoadFromYamlFile( @@ -280,6 
+290,19 @@ void Models::GetModel(const HttpRequestPtr& req, auto resp = cortex_utils::CreateCortexHttpTextAsJsonResponse(ret); resp->setStatusCode(drogon::k200OK); callback(resp); + } else if (model_config.engine == kPythonEngine) { + config::PythonModelConfig python_model_config; + python_model_config.ReadFromYaml( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + ret = python_model_config.ToJson(); + ret["id"] = python_model_config.model; + ret["object"] = "model"; + ret["result"] = "OK"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); } else { config::RemoteModelConfig remote_model_config; remote_model_config.LoadFromYamlFile( @@ -350,6 +373,13 @@ void Models::UpdateModel(const HttpRequestPtr& req, yaml_handler.WriteYamlFile(yaml_fp.string()); message = "Successfully update model ID '" + model_id + "': " + json_body.toStyledString(); + } else if (model_config.engine == kPythonEngine) { + config::PythonModelConfig python_model_config; + python_model_config.ReadFromYaml(yaml_fp.string()); + python_model_config.FromJson(json_body); + python_model_config.ToYaml(yaml_fp.string()); + message = "Successfully update model ID '" + model_id + + "': " + json_body.toStyledString(); } else { config::RemoteModelConfig remote_model_config; remote_model_config.LoadFromYamlFile(yaml_fp.string()); diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index d8e29eb1b..961798d2c 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -127,6 +127,56 @@ void server::FineTuning( LOG_TRACE << "Done fine-tuning"; } +void server::Inference(const HttpRequestPtr& req, + std::function&& callback) { + LOG_TRACE << "Start inference"; + auto q = std::make_shared(); + auto ir = inference_svc_->HandleInference(q, req->getJsonObject()); + LOG_DEBUG << "request: " << req->getJsonObject()->toStyledString(); + if (ir.has_error()) { + auto err = 
ir.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err)); + resp->setStatusCode( + static_cast(std::get<0>(err)["status_code"].asInt())); + callback(resp); + return; + } + LOG_TRACE << "Wait to inference"; + auto [status, res] = q->wait_and_pop(); + LOG_DEBUG << "response: " << res.toStyledString(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode( + static_cast(status["status_code"].asInt())); + callback(resp); + LOG_TRACE << "Done inference"; +} + +void server::RouteRequest( + const HttpRequestPtr& req, + std::function&& callback) { + + LOG_TRACE << "Start route request"; + auto q = std::make_shared(); + auto ir = inference_svc_->HandleRouteRequest(q, req->getJsonObject()); + LOG_DEBUG << "request: " << req->getJsonObject()->toStyledString(); + if (ir.has_error()) { + auto err = ir.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(std::get<1>(err)); + resp->setStatusCode( + static_cast(std::get<0>(err)["status_code"].asInt())); + callback(resp); + return; + } + LOG_TRACE << "Wait to route request"; + auto [status, res] = q->wait_and_pop(); + LOG_DEBUG << "response: " << res.toStyledString(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode( + static_cast(status["status_code"].asInt())); + callback(resp); + LOG_TRACE << "Done route request"; +} + void server::LoadModel(const HttpRequestPtr& req, std::function&& callback) { auto ir = inference_svc_->LoadModel(req->getJsonObject()); diff --git a/engine/controllers/server.h b/engine/controllers/server.h index ef8a32f5d..42214a641 100644 --- a/engine/controllers/server.h +++ b/engine/controllers/server.h @@ -46,8 +46,11 @@ class server : public drogon::HttpController, ADD_METHOD_TO(server::ChatCompletion, "/v1/chat/completions", Options, Post); ADD_METHOD_TO(server::FineTuning, "/v1/fine_tuning/job", Options, Post); ADD_METHOD_TO(server::Embedding, "/v1/embeddings", Options, Post); + 
ADD_METHOD_TO(server::Inference, "/v1/inference", Options, Post); + ADD_METHOD_TO(server::RouteRequest, "/v1/route/request", Options, Post); METHOD_LIST_END + void ChatCompletion( const HttpRequestPtr& req, std::function&& callback) override; @@ -69,6 +72,10 @@ class server : public drogon::HttpController, void FineTuning( const HttpRequestPtr& req, std::function&& callback) override; + void Inference(const HttpRequestPtr& req, + std::function&& callback); + void RouteRequest(const HttpRequestPtr& req, + std::function&& callback); private: void ProcessStreamRes(std::function cb, diff --git a/engine/cortex-common/EngineI.h b/engine/cortex-common/EngineI.h index b796ebaed..b2d290d24 100644 --- a/engine/cortex-common/EngineI.h +++ b/engine/cortex-common/EngineI.h @@ -59,6 +59,14 @@ class EngineI { const std::string& log_path) = 0; virtual void SetLogLevel(trantor::Logger::LogLevel logLevel) = 0; + virtual Json::Value GetRemoteModels() = 0; + virtual void HandleRouteRequest( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void HandleInference( + std::shared_ptr json_body, + std::function&& callback) = 0; + // Stop inflight chat completion in stream mode virtual void StopInferencing(const std::string& model_id) = 0; }; diff --git a/engine/extensions/python-engine/python_engine.cc b/engine/extensions/python-engine/python_engine.cc new file mode 100644 index 000000000..ddf6784e8 --- /dev/null +++ b/engine/extensions/python-engine/python_engine.cc @@ -0,0 +1,860 @@ +#include "python_engine.h" +#include +#include +#include +#include +namespace python_engine { +constexpr const int k200OK = 200; +constexpr const int k400BadRequest = 400; +constexpr const int k409Conflict = 409; +constexpr const int k500InternalServerError = 500; +constexpr const int kFileLoggerOption = 0; + +static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, + std::string* data) { + data->append(ptr, size * nmemb); + return size * nmemb; +} + 
+PythonEngine::PythonEngine() {} + +PythonEngine::~PythonEngine() { + curl_global_cleanup(); +} + +config::PythonModelConfig* PythonEngine::GetModelConfig( + const std::string& model) { + std::shared_lock lock(models_mutex_); + auto it = models_.find(model); + if (it != models_.end()) { + return &it->second; + } + return nullptr; +} +std::string constructWindowsCommandLine(const std::vector& args) { + std::string cmdLine; + for (const auto& arg : args) { + // Simple escaping for Windows command line + std::string escapedArg = arg; + if (escapedArg.find(' ') != std::string::npos) { + // Wrap in quotes and escape existing quotes + for (char& c : escapedArg) { + if (c == '"') + c = '\\'; + } + escapedArg = "\"" + escapedArg + "\""; + } + cmdLine += escapedArg + " "; + } + return cmdLine; +} + +std::vector convertToArgv(const std::vector& args) { + std::vector argv; + for (const auto& arg : args) { + argv.push_back(const_cast(arg.c_str())); + } + argv.push_back(nullptr); + return argv; +} + +pid_t PythonEngine::SpawnProcess(const std::string& model, + const std::vector& command) { + try { +#ifdef _WIN32 + // Windows process creation + STARTUPINFOA si = {0}; + PROCESS_INFORMATION pi = {0}; + si.cb = sizeof(si); + + // Construct command line + std::string cmdLine = constructWindowsCommandLine(command); + + // Convert string to char* for Windows API + char commandBuffer[4096]; + strncpy_s(commandBuffer, cmdLine.c_str(), sizeof(commandBuffer)); + + if (!CreateProcessA(NULL, // lpApplicationName + commandBuffer, // lpCommandLine + NULL, // lpProcessAttributes + NULL, // lpThreadAttributes + FALSE, // bInheritHandles + 0, // dwCreationFlags + NULL, // lpEnvironment + NULL, // lpCurrentDirectory + &si, // lpStartupInfo + &pi // lpProcessInformation + )) { + throw std::runtime_error("Failed to create process on Windows"); + } + + // Store the process ID + pid_t pid = pi.dwProcessId; + processMap[model] = pid; + + // Close handles to avoid resource leaks + 
CloseHandle(pi.hProcess); + CloseHandle(pi.hThread); + + return pid; + +#elif __APPLE__ || __linux__ + // POSIX process creation + pid_t pid; + + // Convert command vector to char*[] + std::vector argv = convertToArgv(command); + // for (auto c : command) { + // std::cout << c << " " << std::endl; + // } + + // Use posix_spawn for cross-platform compatibility + int spawn_result = posix_spawn(&pid, // pid output + command[0].c_str(), // executable path + NULL, // file actions + NULL, // spawn attributes + argv.data(), // argument vector + NULL // environment (inherit) + ); + + if (spawn_result != 0) { + throw std::runtime_error("Failed to spawn process"); + } + + // Store the process ID + processMap[model] = pid; + return pid; + +#else +#error Unsupported platform +#endif + } catch (const std::exception& e) { + LOG_ERROR << "Process spawning error: " << e.what(); + return -1; + } +} +bool PythonEngine::TerminateModelProcess(const std::string& model) { + auto it = processMap.find(model); + if (it == processMap.end()) { + LOG_ERROR << "No process found for model: " << model + << ", removing from list running models."; + models_.erase(model); + return false; + } + +#ifdef _WIN32 + HANDLE hProcess = OpenProcess(PROCESS_TERMINATE, FALSE, it->second); + if (hProcess == NULL) { + LOG_ERROR << "Failed to open process"; + return false; + } + + bool terminated = TerminateProcess(hProcess, 0) == TRUE; + CloseHandle(hProcess); + + if (terminated) { + processMap.erase(it); + return true; + } + +#elif __APPLE__ || __linux__ + int result = kill(it->second, SIGTERM); + if (result == 0) { + processMap.erase(it); + return true; + } +#endif + + return false; +} +CurlResponse PythonEngine::MakeGetRequest(const std::string& model, + const std::string& path) { + auto config = models_[model]; + std::string full_url = "http://localhost:" + config.port + path; + CurlResponse response; + + auto result = curl_utils::SimpleRequest(full_url, RequestType::GET); + if (result.has_error()) { + 
response.error = true; + response.error_message = result.error(); + } else { + response.body = result.value(); + } + return response; +} +CurlResponse PythonEngine::MakeDeleteRequest(const std::string& model, + const std::string& path) { + auto config = models_[model]; + std::string full_url = "http://localhost:" + config.port + path; + CurlResponse response; + + auto result = curl_utils::SimpleRequest(full_url, RequestType::DEL); + + if (result.has_error()) { + response.error = true; + response.error_message = result.error(); + } else { + response.body = result.value(); + } + + return response; +} + +CurlResponse PythonEngine::MakePostRequest(const std::string& model, + const std::string& path, + const std::string& body) { + auto config = models_[model]; + std::string full_url = "http://localhost:" + config.port + path; + + CurlResponse response; + auto result = curl_utils::SimpleRequest(full_url, RequestType::POST, body); + + if (result.has_error()) { + response.error = true; + response.error_message = result.error(); + } else { + response.body = result.value(); + } + return response; +} + +bool PythonEngine::LoadModelConfig(const std::string& model, + const std::string& yaml_path) { + try { + config::PythonModelConfig config; + config.ReadFromYaml(yaml_path); + std::unique_lock lock(models_mutex_); + models_[model] = config; + } catch (const std::exception& e) { + LOG_ERROR << "Failed to load model config: " << e.what(); + return false; + } + + return true; +} + +void PythonEngine::GetModels( + std::shared_ptr json_body, + std::function&& callback) { + + Json::Value response_json; + Json::Value model_array(Json::arrayValue); + + for (const auto& pair : models_) { + auto val = pair.second.ToJson(); + model_array.append(val); + } + + response_json["object"] = "list"; + response_json["data"] = model_array; + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + + 
callback(std::move(status), std::move(response_json)); +} + +void PythonEngine::LoadModel( + std::shared_ptr json_body, + std::function&& callback) { + // TODO: handle a case that can spawn process but the process spawn fail. + pid_t pid; + if (!json_body->isMember("model") || !json_body->isMember("model_path")) { + Json::Value error; + error["error"] = "Missing required fields: model or model_path"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + + const std::string& model = (*json_body)["model"].asString(); + const std::string& model_path = (*json_body)["model_path"].asString(); + if (models_.find(model) != models_.end()) { + Json::Value error; + error["error"] = "Model already loaded!"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k409Conflict; + callback(std::move(status), std::move(error)); + return; + } + + if (!LoadModelConfig(model, model_path)) { + Json::Value error; + error["error"] = "Failed to load model configuration"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + callback(std::move(status), std::move(error)); + return; + } + auto model_config = models_[model]; + auto model_folder_path = model_config.files[0]; + auto data_folder_path = + std::filesystem::path(model_folder_path) / std::filesystem::path("venv"); + try { +#ifdef _WIN32 + auto executable = std::filesystem::path(data_folder_path) / + std::filesystem::path("Scripts"); +#else + auto executable = + std::filesystem::path(data_folder_path) / std::filesystem::path("bin"); +#endif + + auto executable_str = + (executable / std::filesystem::path(model_config.command[0])).string(); + auto command = model_config.command; + 
command[0] = executable_str; + command.push_back((std::filesystem::path(model_folder_path) / + std::filesystem::path(model_config.script)) + .string()); + std::list args{"--port", + model_config.port, + "--log_path", + (file_manager_utils::GetCortexLogPath() / + std::filesystem::path(model_config.log_path)) + .string(), + "--log_level", + model_config.log_level}; + if (!model_config.extra_params.isNull() && + model_config.extra_params.isObject()) { + for (const auto& key : model_config.extra_params.getMemberNames()) { + const Json::Value& value = model_config.extra_params[key]; + + // Convert key to string with -- prefix + std::string param_key = "--" + key; + + // Handle different JSON value types + if (value.isString()) { + args.emplace_back(param_key); + args.emplace_back(value.asString()); + } else if (value.isInt()) { + args.emplace_back(param_key); + args.emplace_back(std::to_string(value.asInt())); + } else if (value.isDouble()) { + args.emplace_back(param_key); + args.emplace_back(std::to_string(value.asDouble())); + } else if (value.isBool()) { + // For boolean, only add the flag if true + if (value.asBool()) { + args.emplace_back(param_key); + } + } + } + } + + // Add the parsed arguments to the command + command.insert(command.end(), args.begin(), args.end()); + pid = SpawnProcess(model, command); + if (pid == -1) { + std::unique_lock lock(models_mutex_); + if (models_.find(model) != models_.end()) { + models_.erase(model); + } + + Json::Value error; + error["error"] = "Fail to spawn process with pid -1"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + callback(std::move(status), std::move(error)); + return; + } + } catch (const std::exception& e) { + std::unique_lock lock(models_mutex_); + if (models_.find(model) != models_.end()) { + models_.erase(model); + } + + Json::Value error; + error["error"] = e.what(); + Json::Value status; + 
status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value response; + response["status"] = + "Model loaded successfully with pid: " + std::to_string(pid); + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + callback(std::move(status), std::move(response)); +} + +void PythonEngine::UnloadModel( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model")) { + Json::Value error; + error["error"] = "Missing required field: model"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + + const std::string& model = (*json_body)["model"].asString(); + + { + std::unique_lock lock(models_mutex_); + if (TerminateModelProcess(model)) { + models_.erase(model); + } else { + Json::Value error; + error["error"] = "Fail to terminate process with id: " + + std::to_string(processMap[model]); + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + } + + Json::Value response; + response["status"] = "Model unloaded successfully"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + callback(std::move(status), std::move(response)); +} + +void PythonEngine::HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) {} + +void PythonEngine::HandleInference( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model")) { + Json::Value 
error; + error["error"] = "Missing required field: model is required!"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + std::string method = "post"; + std::string path = "/inference"; + std::string transform_request = + (*json_body).get("transform_request", "").asString(); + std::string transform_response = + (*json_body).get("transform_response", "").asString(); + std::string model = (*json_body)["model"].asString(); + Json::Value body = (*json_body)["body"]; + + // Transform Request + std::string transformed_request; + if (!transform_request.empty()) { + + try { + // Validate JSON body + if (!body || body.isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } + + // Render with error handling + try { + transformed_request = renderer_.Render(transform_request, *json_body); + } catch (const std::exception& e) { + throw std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error in TransformRequest: " << e.what(); + LOG_WARN << "Using original request body"; + transformed_request = body.toStyledString(); + } + } else { + transformed_request = body.toStyledString(); + } + + // End Transform request + + CurlResponse response; + if (method == "post") { + response = MakePostRequest(model, path, transformed_request); + } else if (method == "get") { + response = MakeGetRequest(model, path); + } else if (method == "delete") { + response = MakeDeleteRequest(model, path); + } else { + Json::Value error; + error["error"] = + "method not supported! 
Supported methods are: post, get, delete"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + + if (response.error) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } + + if (!transform_response.empty()) { + // Transform Response + std::string response_str; + try { + // Validate JSON body + if (!response_json || response_json.isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } + // Render with error handling + try { + response_str = renderer_.Render(transform_response, response_json); + } catch (const std::exception& e) { + throw std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error in TransformRequest: " << e.what(); + LOG_WARN << "Using original request body"; + response_str = response_json.toStyledString(); + } + + Json::Reader reader_final; + Json::Value response_json_final; + if (!reader_final.parse(response_str, response_json_final)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value 
error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + + callback(std::move(status), std::move(response_json_final)); + } else { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + + callback(std::move(status), std::move(response_json)); + } +} +Json::Value PythonEngine::GetRemoteModels() { + return Json::Value(); +} +void PythonEngine::StopInferencing(const std::string& model_id) {} +void PythonEngine::HandleRouteRequest( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model") || !json_body->isMember("method") || + !json_body->isMember("path")) { + Json::Value error; + error["error"] = + "Missing required field: model, method and path are required!"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + std::string method = (*json_body)["method"].asString(); + std::string path = (*json_body)["path"].asString(); + std::string transform_request = + (*json_body).get("transform_request", "").asString(); + std::string transform_response = + (*json_body).get("transform_response", "").asString(); + std::string model = (*json_body)["model"].asString(); + Json::Value body = (*json_body)["body"]; + + // Transform Request + std::string transformed_request; + if (!transform_request.empty()) { + + try { + // Validate JSON body + if (!body || body.isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } + + // Render with error handling + try { + transformed_request = renderer_.Render(transform_request, *json_body); + } catch (const std::exception& e) { + throw 
std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error in TransformRequest: " << e.what(); + LOG_WARN << "Using original request body"; + transformed_request = body.toStyledString(); + } + } else { + transformed_request = body.toStyledString(); + } + + // End Transform request + + CurlResponse response; + if (method == "post") { + response = MakePostRequest(model, path, transformed_request); + } else if (method == "get") { + response = MakeGetRequest(model, path); + } else if (method == "delete") { + response = MakeDeleteRequest(model, path); + } else { + Json::Value error; + error["error"] = + "method not supported! Supported methods are: post, get, delete"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + + if (response.error) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } + + if (!transform_response.empty()) { + // Transform Response + std::string response_str; + try { + // Validate JSON body + if (!response_json || response_json.isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } + // Render with error 
handling + try { + response_str = renderer_.Render(transform_response, response_json); + } catch (const std::exception& e) { + throw std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error in TransformRequest: " << e.what(); + LOG_WARN << "Using original request body"; + response_str = response_json.toStyledString(); + } + + Json::Reader reader_final; + Json::Value response_json_final; + if (!reader_final.parse(response_str, response_json_final)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + + callback(std::move(status), std::move(response_json_final)); + } else { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + + callback(std::move(status), std::move(response_json)); + } +} + +void PythonEngine::GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model")) { + Json::Value error; + error["error"] = "Missing required field: model"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + auto model = json_body->get("model", "").asString(); + auto model_config = models_[model]; + auto health_endpoint = model_config.heath_check; + auto response_health = MakeGetRequest(model, health_endpoint.path); + + if (response_health.error) { + 
Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = response_health.error_message; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value response; + response["model"] = model; + response["model_loaded"] = true; + response["model_data"] = model_config.ToJson(); + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + callback(std::move(status), std::move(response)); +} + +// Implement remaining virtual functions +void PythonEngine::HandleEmbedding( + std::shared_ptr, + std::function&& callback) { + callback(Json::Value(), Json::Value()); +} + +bool PythonEngine::IsSupported(const std::string& f) { + if (f == "HandleChatCompletion" || f == "LoadModel" || f == "UnloadModel" || + f == "GetModelStatus" || f == "GetModels" || f == "SetFileLogger" || + f == "SetLogLevel") { + return true; + } + return false; +} + +bool PythonEngine::SetFileLogger(int max_log_lines, + const std::string& log_path) { + if (!async_file_logger_) { + async_file_logger_ = std::make_unique(); + } + + async_file_logger_->setFileName(log_path); + async_file_logger_->setMaxLines(max_log_lines); // Keep last 100000 lines + async_file_logger_->startLogging(); + trantor::Logger::setOutputFunction( + [&](const char* msg, const uint64_t len) { + if (async_file_logger_) + async_file_logger_->output_(msg, len); + }, + [&]() { + if (async_file_logger_) + async_file_logger_->flush(); + }); + freopen(log_path.c_str(), "w", stderr); + freopen(log_path.c_str(), "w", stdout); + return true; +} + +void PythonEngine::SetLogLevel(trantor::Logger::LogLevel log_level) { + trantor::Logger::setLogLevel(log_level); +} + +void PythonEngine::Load(EngineLoadOption opts) { + // Develop register model here on loading engine +}; + +void 
PythonEngine::Unload(EngineUnloadOption opts) {}; + +// extern "C" { +// EngineI* get_engine() { +// return new PythonEngine(); +// } +// } +} // namespace python_engine \ No newline at end of file diff --git a/engine/extensions/python-engine/python_engine.h b/engine/extensions/python-engine/python_engine.h new file mode 100644 index 000000000..7b112f435 --- /dev/null +++ b/engine/extensions/python-engine/python_engine.h @@ -0,0 +1,166 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "config/model_config.h" +#include "cortex-common/EngineI.h" +#include "extensions/template_renderer.h" +#include "utils/file_logger.h" +#include "utils/file_manager_utils.h" + +#include "utils/curl_utils.h" +#ifdef _WIN32 +#include +#include +using pid_t = DWORD; +#elif __APPLE__ || __linux__ +#include +#include +#include +#include +#include +#endif +// Helper for CURL response +namespace python_engine { +struct StreamContext { + std::shared_ptr> callback; + std::string buffer; +}; + +static size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, + void* userdata) { + auto* context = static_cast(userdata); + std::string chunk(ptr, size * nmemb); + + context->buffer += chunk; + + // Process complete lines + size_t pos; + while ((pos = context->buffer.find('\n')) != std::string::npos) { + std::string line = context->buffer.substr(0, pos); + context->buffer = context->buffer.substr(pos + 1); + + // Skip empty lines + if (line.empty() || line == "\r") + continue; + + // Remove "data: " prefix if present + // if (line.substr(0, 6) == "data: ") + // { + // line = line.substr(6); + // } + + // Skip [DONE] message + std::cout << line << std::endl; + if (line == "data: [DONE]") { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), Json::Value()); + break; + } + + // Parse the JSON + Json::Value chunk_json; 
+ chunk_json["data"] = line + "\n\n"; + Json::Reader reader; + + Json::Value status; + status["is_done"] = false; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), std::move(chunk_json)); + } + + return size * nmemb; +} + +struct CurlResponse { + std::string body; + bool error{false}; + std::string error_message; +}; + +class PythonEngine : public EngineI { + private: + // Model configuration + + // Thread-safe model config storage + mutable std::shared_mutex models_mutex_; + std::unordered_map models_; + extensions::TemplateRenderer renderer_; + std::unique_ptr async_file_logger_; + std::unordered_map processMap; + + // Helper functions + CurlResponse MakePostRequest(const std::string& model, + const std::string& path, + const std::string& body); + CurlResponse MakeGetRequest(const std::string& model, + const std::string& path); + CurlResponse MakeDeleteRequest(const std::string& model, + const std::string& path); + + // Process manager functions + pid_t SpawnProcess(const std::string& model, + const std::vector& command); + bool TerminateModelProcess(const std::string& model); + + // Internal model management + bool LoadModelConfig(const std::string& model, const std::string& yaml_path); + config::PythonModelConfig* GetModelConfig(const std::string& model); + + public: + PythonEngine(); + ~PythonEngine(); + + void Load(EngineLoadOption opts) override; + + void Unload(EngineUnloadOption opts) override; + + // Main interface implementations + void GetModels( + std::shared_ptr json_body, + std::function&& callback) override; + + void HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) override; + + void LoadModel( + std::shared_ptr json_body, + std::function&& callback) override; + + void UnloadModel( + std::shared_ptr json_body, + std::function&& callback) override; + + void GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) override; 
+ + // Other required virtual functions + void HandleEmbedding( + std::shared_ptr json_body, + std::function&& callback) override; + bool IsSupported(const std::string& feature) override; + bool SetFileLogger(int max_log_lines, const std::string& log_path) override; + void SetLogLevel(trantor::Logger::LogLevel logLevel) override; + void HandleRouteRequest( + std::shared_ptr json_body, + std::function&& callback) override; + void HandleInference( + std::shared_ptr json_body, + std::function&& callback) override; + Json::Value GetRemoteModels() override; + void StopInferencing(const std::string& model_id) override; +}; +} // namespace python_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index d8dfbad61..6f08b5403 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -8,7 +8,7 @@ #include #include #include "cortex-common/remote_enginei.h" -#include "extensions/remote-engine/template_renderer.h" +#include "extensions/template_renderer.h" #include "utils/engine_constants.h" #include "utils/file_logger.h" // Helper for CURL response @@ -21,7 +21,7 @@ struct StreamContext { // Cache value for Anthropic std::string id; std::string model; - TemplateRenderer& renderer; + extensions::TemplateRenderer& renderer; std::string stream_template; }; struct CurlResponse { @@ -46,7 +46,7 @@ class RemoteEngine : public RemoteEngineI { // Thread-safe model config storage mutable std::shared_mutex models_mtx_; std::unordered_map models_; - TemplateRenderer renderer_; + extensions::TemplateRenderer renderer_; Json::Value metadata_; std::string chat_req_template_; std::string chat_res_template_; diff --git a/engine/extensions/remote-engine/template_renderer.cc b/engine/extensions/template_renderer.cc similarity index 99% rename from engine/extensions/remote-engine/template_renderer.cc rename to 
engine/extensions/template_renderer.cc index 15514d17c..32e7d72f5 100644 --- a/engine/extensions/remote-engine/template_renderer.cc +++ b/engine/extensions/template_renderer.cc @@ -7,7 +7,7 @@ #include #include #include "utils/logging_utils.h" -namespace remote_engine { +namespace extensions { TemplateRenderer::TemplateRenderer() { // Configure Inja environment env_.set_trim_blocks(true); diff --git a/engine/extensions/remote-engine/template_renderer.h b/engine/extensions/template_renderer.h similarity index 97% rename from engine/extensions/remote-engine/template_renderer.h rename to engine/extensions/template_renderer.h index f59e7cc93..7eccef2eb 100644 --- a/engine/extensions/remote-engine/template_renderer.h +++ b/engine/extensions/template_renderer.h @@ -14,7 +14,7 @@ #include #include // clang-format on -namespace remote_engine { +namespace extensions { class TemplateRenderer { public: TemplateRenderer(); diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 53a4bfa65..39e6e7961 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -3,10 +3,15 @@ #include #include #include + #include #include "algorithm" #include "database/engines.h" + +#include "extensions/python-engine/python_engine.h" + #include "extensions/remote-engine/remote_engine.h" + #include "utils/archive_utils.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" @@ -183,6 +188,7 @@ cpp::result EngineService::UninstallEngineVariant( const std::string& engine, const std::optional version, const std::optional variant) { auto ne = NormalizeEngine(engine); + // TODO: handle uninstall remote engine // only delete a remote engine if no model are using it auto exist_engine = GetEngineByNameAndVariant(engine); @@ -715,6 +721,14 @@ cpp::result EngineService::LoadEngine( return {}; } + // Check for python engine + + if (engine_name == kPythonEngine) { + engines_[engine_name].engine = new 
python_engine::PythonEngine(); + CTL_INF("Loaded engine: " << engine_name); + return {}; + } + // Check for remote engine if (IsRemoteEngine(engine_name)) { auto exist_engine = GetEngineByNameAndVariant(engine_name); @@ -884,6 +898,7 @@ EngineService::GetEngineDirPath(const std::string& engine_name) { cpp::result EngineService::UnloadEngine( const std::string& engine) { auto ne = NormalizeEngine(engine); + std::lock_guard lock(engines_mutex_); if (!IsEngineLoaded(ne)) { return cpp::fail("Engine " + ne + " is not loaded yet!"); @@ -942,6 +957,10 @@ cpp::result EngineService::IsEngineReady( } // End hard code + // Check for python engine + if (engine == kPythonEngine) { + return true; + } auto os = hw_inf_.sys_inf->os; if (os == kMacOs && (ne == kOnnxRepo || ne == kTrtLlmRepo)) { diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index fcd3fdda9..a460582c6 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -132,6 +132,7 @@ class EngineService : public EngineServiceI { cpp::result UpdateEngine( const std::string& engine); + cpp::result, std::string> GetEngines(); cpp::result GetEngineById(int id); diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 9d8e9f4f8..3668fb6fe 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -112,6 +112,64 @@ cpp::result InferenceService::HandleEmbedding( return {}; } +cpp::result InferenceService::HandleInference( + std::shared_ptr q, std::shared_ptr json_body) { + std::string engine_type; + if (!HasFieldInReq(json_body, "engine")) { + engine_type = kLlamaRepo; + } else { + engine_type = (*(json_body)).get("engine", kLlamaRepo).asString(); + } + + auto engine_result = engine_service_->GetLoadedEngine(engine_type); + if (engine_result.has_error()) { + Json::Value res; + Json::Value stt; + res["message"] = "Engine is not loaded yet"; + stt["status_code"] = drogon::k400BadRequest; + 
LOG_WARN << "Engine is not loaded yet"; + return cpp::fail(std::make_pair(stt, res)); + } + + auto cb = [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }; + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->HandleInference(json_body, std::move(cb)); + } + return {}; +} + +cpp::result InferenceService::HandleRouteRequest( + std::shared_ptr q, std::shared_ptr json_body) { + std::string engine_type; + if (!HasFieldInReq(json_body, "engine")) { + engine_type = kLlamaRepo; + } else { + engine_type = (*(json_body)).get("engine", kLlamaRepo).asString(); + } + + auto engine_result = engine_service_->GetLoadedEngine(engine_type); + if (engine_result.has_error()) { + Json::Value res; + Json::Value stt; + res["message"] = "Engine is not loaded yet"; + stt["status_code"] = drogon::k400BadRequest; + LOG_WARN << "Engine is not loaded yet"; + return cpp::fail(std::make_pair(stt, res)); + } + + auto cb = [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }; + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->HandleRouteRequest(json_body, std::move(cb)); + } + return {}; +} + InferResult InferenceService::LoadModel( std::shared_ptr json_body) { std::string engine_type; diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index 75b07b1a3..f23be3f23 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -3,6 +3,7 @@ #include #include #include +#include "extensions/remote-engine/remote_engine.h" #include "services/engine_service.h" #include "services/model_service.h" #include "utils/result.hpp" @@ -41,6 +42,12 @@ class InferenceService { cpp::result HandleEmbedding( std::shared_ptr q, std::shared_ptr json_body); + cpp::result HandleInference( + std::shared_ptr q, std::shared_ptr json_body); + + cpp::result HandleRouteRequest( + std::shared_ptr q, std::shared_ptr 
json_body); + InferResult LoadModel(std::shared_ptr json_body); InferResult UnloadModel(const std::string& engine, diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 2d69e0f17..c7925360b 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -9,7 +10,10 @@ #include "config/yaml_config.h" #include "database/models.h" #include "hardware_service.h" +#include "utils/archive_utils.h" + #include "services/inference_service.h" + #include "utils/cli_selection_utils.h" #include "utils/engine_constants.h" #include "utils/file_manager_utils.h" @@ -17,6 +21,7 @@ #include "utils/huggingface_utils.h" #include "utils/logging_utils.h" #include "utils/result.hpp" +#include "utils/set_permission_utils.h" #include "utils/string_utils.h" #include "utils/widechar_conv.h" @@ -79,8 +84,7 @@ void ParseGguf(DatabaseService& db_service, CTL_ERR("Error adding model to modellist: " + result.error()); } } else { - if (auto m = db_service.GetModelInfo(ggufDownloadItem.id); - m.has_value()) { + if (auto m = db_service.GetModelInfo(ggufDownloadItem.id); m.has_value()) { auto upd_m = m.value(); upd_m.status = cortex::db::ModelStatus::Downloaded; if (auto r = db_service.UpdateModelEntry(ggufDownloadItem.id, upd_m); @@ -99,7 +103,7 @@ cpp::result GetDownloadTask( .pathParams = {"api", "models", "cortexso", modelId, "tree", branch}, }; - auto result = curl_utils::SimpleGetJson(url.ToFullPath()); + auto result = curl_utils::SimpleGetJsonRecursive(url.ToFullPath()); if (result.has_error()) { return cpp::fail("Model " + modelId + " not found"); } @@ -110,6 +114,7 @@ cpp::result GetDownloadTask( file_manager_utils::CreateDirectoryRecursively(model_container_path.string()); for (const auto& value : result.value()) { + // std::cout << "value object: " << value.toStyledString() << std::endl; auto path = value["path"].asString(); if (path == ".gitattributes" || 
path == ".gitignore" || path == "README.md") { @@ -121,6 +126,9 @@ cpp::result GetDownloadTask( .pathParams = {"cortexso", modelId, "resolve", branch, path}}; auto local_path = model_container_path / path; + if (!std::filesystem::exists(local_path.parent_path())) { + std::filesystem::create_directories(local_path.parent_path()); + } download_items.push_back( DownloadItem{.id = path, .downloadUrl = download_url.ToFullPath(), @@ -466,7 +474,8 @@ cpp::result ModelService::HandleUrl( model_size = model_size + item.bytes.value_or(0); } auto gguf_download_item = finishedTask.items[0]; - ParseGguf(*db_service_, gguf_download_item, author, std::nullopt, model_size); + ParseGguf(*db_service_, gguf_download_item, author, std::nullopt, + model_size); }; auto result = download_service_->AddDownloadTask(downloadTask, on_finished); @@ -528,15 +537,79 @@ ModelService::DownloadModelFromCortexsoAsync( config::YamlHandler yaml_handler; yaml_handler.ModelConfigFromFile(model_yml_item->localPath.string()); auto mc = yaml_handler.GetModelConfig(); - mc.model = unique_model_id; + if (mc.engine == kPythonEngine) { // process for Python engine + config::PythonModelConfig python_model_config; + python_model_config.ReadFromYaml(model_yml_item->localPath.string()); + python_model_config.files.push_back( + model_yml_item->localPath.parent_path().string()); + python_model_config.ToYaml(model_yml_item->localPath.string()); + // unzip venv.zip + auto model_folder = model_yml_item->localPath.parent_path(); + auto venv_path = model_folder / std::filesystem::path("venv"); + if (!std::filesystem::exists(venv_path)) { + std::filesystem::create_directories(venv_path); + } + auto venv_zip = model_folder / std::filesystem::path("venv.zip"); + if (std::filesystem::exists(venv_zip)) { + if (archive_utils::ExtractArchive(venv_zip.string(), + venv_path.string())) { + std::filesystem::remove_all(venv_zip); + CTL_INF("Successfully extract venv.zip"); + // If extract success create pyvenv.cfg + std::ofstream 
pyvenv_cfg(venv_path / + std::filesystem::path("pyvenv.cfg")); +#ifdef _WIN32 + pyvenv_cfg << "home = " + << (venv_path / std::filesystem::path("Scripts")).string() + << std::endl; + pyvenv_cfg << "executable = " + << (venv_path / std::filesystem::path("Scripts") / + std::filesystem::path("python.exe")) + .string() + << std::endl; - uint64_t model_size = 0; - for (const auto& item : finishedTask.items) { - model_size = model_size + item.bytes.value_or(0); +#else + pyvenv_cfg << "home = " + << (venv_path / std::filesystem::path("bin/")).string() + << std::endl; + pyvenv_cfg + << "executable = " + << (venv_path / std::filesystem::path("bin/python")).string() + << std::endl; +#endif + + // Close the file + pyvenv_cfg.close(); + // Add executable permission to python + +#ifdef _WIN32 + set_permission_utils::SetExecutePermissionsRecursive( + venv_path / std::filesystem::path("Scripts")); +#else + set_permission_utils::SetExecutePermissionsRecursive( + venv_path / std::filesystem::path("bin")); +#endif + + } else { + CTL_ERR("Failed to extract venv.zip"); + }; + + } else { + CTL_ERR( + "venv.zip not found in model folder: " << model_folder.string()); + } + + } else { + mc.model = unique_model_id; + + uint64_t model_size = 0; + for (const auto& item : finishedTask.items) { + model_size = model_size + item.bytes.value_or(0); + } + mc.size = model_size; + yaml_handler.UpdateModelConfig(mc); + yaml_handler.WriteYamlFile(model_yml_item->localPath.string()); } - mc.size = model_size; - yaml_handler.UpdateModelConfig(mc); - yaml_handler.WriteYamlFile(model_yml_item->localPath.string()); auto rel = file_manager_utils::ToRelativeCortexDataPath(model_yml_item->localPath); @@ -583,7 +656,8 @@ cpp::result ModelService::DownloadModelFromCortexso( } std::string model_id{name + ":" + branch}; - auto on_finished = [this, branch, model_id](const DownloadTask& finishedTask) { + auto on_finished = [this, branch, + model_id](const DownloadTask& finishedTask) { const DownloadItem* 
model_yml_item = nullptr; auto need_parse_gguf = true; @@ -754,18 +828,75 @@ cpp::result ModelService::StartModel( constexpr const int kDefautlContextLength = 8192; int max_model_context_length = kDefautlContextLength; Json::Value json_data; - // Currently we don't support download vision models, so we need to bypass check - if (!bypass_model_check) { - auto model_entry = db_service_->GetModelInfo(model_handle); - if (model_entry.has_error()) { - CTL_WRN("Error: " + model_entry.error()); - return cpp::fail(model_entry.error()); - } - yaml_handler.ModelConfigFromFile( + auto model_entry = db_service_->GetModelInfo(model_handle); + if (model_entry.has_error()) { + CTL_WRN("Error: " + model_entry.error()); + return cpp::fail(model_entry.error()); + } + yaml_handler.ModelConfigFromFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + auto mc = yaml_handler.GetModelConfig(); + + // Check if Python model first + if (mc.engine == kPythonEngine) { + + config::PythonModelConfig python_model_config; + python_model_config.ReadFromYaml( + fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.value().path_to_model_yaml)) .string()); - auto mc = yaml_handler.GetModelConfig(); + // Start all depends model + auto depends = python_model_config.depends; + for (auto& depend : depends) { + Json::Value temp; + auto res = StartModel(depend, temp, false); + if (res.has_error()) { + CTL_WRN("Error: " + res.error()); + for (auto& depend : depends) { + if (depend != model_handle) { + StopModel(depend); + } + } + return cpp::fail("Model failed to start dependency '" + depend + + "' : " + res.error()); + } + } + + json_data["model"] = model_handle; + json_data["model_path"] = + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string(); + json_data["engine"] = mc.engine; + assert(!!inference_svc_); + // Check if python engine + + auto ir = + inference_svc_->LoadModel(std::make_shared(json_data)); + auto 
status = std::get<0>(ir)["status_code"].asInt(); + auto data = std::get<1>(ir); + + if (status == drogon::k200OK) { + return StartModelResult{.success = true, .warning = ""}; + } else if (status == drogon::k409Conflict) { + CTL_INF("Model '" + model_handle + "' is already loaded"); + return StartModelResult{.success = true, .warning = ""}; + } else { + // only report to user the error + for (auto& depend : depends) { + + StopModel(depend); + } + } + CTL_ERR("Model failed to start with status code: " << status); + return cpp::fail("Model failed to start: " + data["message"].asString()); + } + + // Currently we don't support download vision models, so we need to bypass check + if (!bypass_model_check) { // Running remote model if (engine_svc_->IsRemoteEngine(mc.engine)) { @@ -856,6 +987,8 @@ cpp::result ModelService::StartModel( } assert(!!inference_svc_); + // Check if python engine + auto ir = inference_svc_->LoadModel(std::make_shared(json_data)); auto status = std::get<0>(ir)["status_code"].asInt(); @@ -917,6 +1050,23 @@ cpp::result ModelService::StopModel( if (bypass_check) { engine_name = kLlamaEngine; } + + // Update for python engine + if (engine_name == kPythonEngine) { + auto model_entry = db_service_->GetModelInfo(model_handle); + config::PythonModelConfig python_model_config; + python_model_config.ReadFromYaml( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + // Stop all depends model + auto depends = python_model_config.depends; + for (auto& depend : depends) { + StopModel(depend); + } + } + + // assert(inference_svc_); auto ir = inference_svc_->UnloadModel(engine_name, model_handle); auto status = std::get<0>(ir)["status_code"].asInt(); diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt index 0df46cfc2..6ca836158 100644 --- a/engine/test/components/CMakeLists.txt +++ b/engine/test/components/CMakeLists.txt @@ -16,7 +16,7 @@ add_executable(${PROJECT_NAME} 
${CMAKE_CURRENT_SOURCE_DIR}/../../utils/file_manager_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/curl_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/system_info_utils.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../../extensions/remote-engine/template_renderer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../extensions/template_renderer.cc ) find_package(Drogon CONFIG REQUIRED) diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc index bfac76f49..5f1b85044 100644 --- a/engine/test/components/test_remote_engine.cc +++ b/engine/test/components/test_remote_engine.cc @@ -1,4 +1,4 @@ -#include "extensions/remote-engine/template_renderer.h" +#include "extensions/template_renderer.h" #include "gtest/gtest.h" #include "utils/json_helper.h" @@ -42,7 +42,7 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) { auto data = json_helper::ParseJsonString(message_with_system); - remote_engine::TemplateRenderer rdr; + extensions::TemplateRenderer rdr; auto res = rdr.Render(tpl, data); auto res_json = json_helper::ParseJsonString(res); @@ -69,7 +69,7 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) { auto data = json_helper::ParseJsonString(message_without_system); - remote_engine::TemplateRenderer rdr; + extensions::TemplateRenderer rdr; auto res = rdr.Render(tpl, data); auto res_json = json_helper::ParseJsonString(res); diff --git a/engine/utils/config_yaml_utils.h b/engine/utils/config_yaml_utils.h index f9925ea86..73eaf3084 100644 --- a/engine/utils/config_yaml_utils.h +++ b/engine/utils/config_yaml_utils.h @@ -2,8 +2,13 @@ #include #include +#include "utils/engine_constants.h" +#include "utils/logging_utils.h" + #include + #include "utils/engine_constants.h" + #include "utils/result.hpp" namespace config_yaml_utils { @@ -20,7 +25,9 @@ const std::vector kDefaultEnabledOrigins{ "http://localhost:39281", "http://127.0.0.1:39281", "http://0.0.0.0:39281"}; constexpr const auto kDefaultNoProxy = "example.com,::1,localhost,127.0.0.1"; const 
std::vector kDefaultSupportedEngines{ - kLlamaEngine, kOnnxEngine, kTrtLlmEngine}; + kLlamaEngine, kOnnxEngine, kTrtLlmEngine, kPythonEngine}; + + struct CortexConfig { std::string logFolderPath; @@ -58,6 +65,7 @@ struct CortexConfig { bool verifyPeerSsl; bool verifyHostSsl; + std::string sslCertPath; std::string sslKeyPath; std::vector supportedEngines; diff --git a/engine/utils/curl_utils.cc b/engine/utils/curl_utils.cc index 71f263a6a..be82b5cfa 100644 --- a/engine/utils/curl_utils.cc +++ b/engine/utils/curl_utils.cc @@ -260,6 +260,41 @@ cpp::result SimpleGetJson(const std::string& url, return root; } +cpp::result SimpleGetJsonRecursive( + const std::string& url, const int timeout) { + auto result = SimpleGetJson(url, timeout); + if (result.has_error()) { + return result; + } + auto root = result.value(); + + if (root.isArray()) { + for (const auto& value : root) { + if (value["type"].asString() == "directory") { + auto temp = SimpleGetJsonRecursive(url + "/" + value["path"].asString(), + timeout); + if (!temp.has_error()) { + if (temp.value().isArray()) { + for (const auto& item : temp.value()) { + root.append(item); + } + } else { + root.append(temp.value()); + } + } + } + } + for (Json::ArrayIndex i = 0; i < root.size();) { + if (root[i].isMember("type") && root[i]["type"] == "directory") { + root.removeIndex(i, nullptr); + } else { + ++i; + } + } + } + return root; +} + cpp::result SimplePostJson(const std::string& url, const std::string& body) { auto result = SimpleRequest(url, RequestType::POST, body); diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index 64b5fc339..f33b7e8e5 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -34,6 +34,8 @@ cpp::result ReadRemoteYaml(const std::string& url); */ cpp::result SimpleGetJson(const std::string& url, const int timeout = -1); +cpp::result SimpleGetJsonRecursive(const std::string& url, + const int timeout = -1); cpp::result SimplePostJson( const std::string& url, const 
std::string& body = ""); diff --git a/engine/utils/engine_constants.h b/engine/utils/engine_constants.h index dcdf6a443..9392ede35 100644 --- a/engine/utils/engine_constants.h +++ b/engine/utils/engine_constants.h @@ -3,12 +3,17 @@ constexpr const auto kOnnxEngine = "onnxruntime"; constexpr const auto kLlamaEngine = "llama-cpp"; constexpr const auto kTrtLlmEngine = "tensorrt-llm"; + +constexpr const auto kPythonEngine = "python-engine"; + constexpr const auto kOpenAiEngine = "openai"; constexpr const auto kAnthropicEngine = "anthropic"; + constexpr const auto kRemote = "remote"; constexpr const auto kLocal = "local"; + constexpr const auto kOnnxRepo = "cortex.onnx"; constexpr const auto kLlamaRepo = "cortex.llamacpp"; constexpr const auto kTrtLlmRepo = "cortex.tensorrt-llm"; diff --git a/engine/utils/file_manager_utils.cc b/engine/utils/file_manager_utils.cc index aee65020c..a83c93efa 100644 --- a/engine/utils/file_manager_utils.cc +++ b/engine/utils/file_manager_utils.cc @@ -185,6 +185,7 @@ config_yaml_utils::CortexConfig GetDefaultConfig() { .noProxy = config_yaml_utils::kDefaultNoProxy, .verifyPeerSsl = true, .verifyHostSsl = true, + .sslCertPath = "", .sslKeyPath = "", .supportedEngines = config_yaml_utils::kDefaultSupportedEngines, diff --git a/engine/utils/jinja_utils.h b/engine/utils/jinja_utils.h index f614f4745..12244599f 100644 --- a/engine/utils/jinja_utils.h +++ b/engine/utils/jinja_utils.h @@ -3,7 +3,7 @@ #include #include -#include "extensions/remote-engine/template_renderer.h" +#include "extensions/template_renderer.h" #include "utils/chat-template.hpp" #include "utils/result.hpp" @@ -14,7 +14,7 @@ inline cpp::result RenderTemplate( bool add_generation_prompt = true) { try { auto converted_json = - remote_engine::TemplateRenderer().ConvertJsonValue(data); + extensions::TemplateRenderer().ConvertJsonValue(data); minja::chat_template chat_tmpl(tmpl, add_bos_token ? bos_token : "", add_eos_token ? 
eos_token : ""); diff --git a/engine/utils/set_permission_utils.h b/engine/utils/set_permission_utils.h new file mode 100644 index 000000000..c1c08ce8f --- /dev/null +++ b/engine/utils/set_permission_utils.h @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif +#include "utils/logging_utils.h" +namespace set_permission_utils { +// Cross-platform method to set execute permission for a single file +[[nodiscard]] inline bool SetExecutePermission(const std::filesystem::path& filePath, + bool ownerOnly = false) noexcept { + try { + std::filesystem::perms current_perms = std::filesystem::status(filePath).permissions(); + std::filesystem::perms new_perms; + + if (ownerOnly) { + new_perms = current_perms | std::filesystem::perms::owner_exec; + // Remove group and others execute permissions + new_perms &= ~(std::filesystem::perms::group_exec | std::filesystem::perms::others_exec); + } else { + new_perms = current_perms | std::filesystem::perms::owner_exec | + std::filesystem::perms::group_exec | + std::filesystem::perms::others_exec; + } + + std::filesystem::permissions(filePath, new_perms, + std::filesystem::perm_options::replace); + return true; + } catch (const std::filesystem::filesystem_error& e) { + CTL_ERR("Permission error for file " << filePath.string() + << ": " << e.what()); + return false; + } catch (const std::exception& e) { + CTL_ERR("Unexpected error for file " << filePath.string() + << ": " << e.what()); + return false; + } +} + +[[nodiscard]] inline std::vector SetExecutePermissionsRecursive( + const std::filesystem::path& directoryPath, + bool ownerOnly = false, + bool skipDirectories = true) { + std::vector modifiedFiles; + modifiedFiles.reserve(100); // Reserve space to prevent frequent reallocations + + try { + const auto options = std::filesystem::directory_options::skip_permission_denied | + std::filesystem::directory_options::follow_directory_symlink; + + for (const auto& entry 
: + std::filesystem::recursive_directory_iterator(directoryPath, options)) { + if (skipDirectories && entry.is_directory()) { + continue; + } + + if (entry.is_regular_file()) { + if (SetExecutePermission(entry.path(), ownerOnly)) { + modifiedFiles.push_back(entry.path()); + } + } + } + } catch (const std::filesystem::filesystem_error& e) { + CTL_ERR("Filesystem error: " << e.what()); + } + + return modifiedFiles; +} + +} // namespace set_permission_utils \ No newline at end of file From c893b4bbd16c2e16a1b4969d6879a46c1989f380 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 <35255081+nguyenhoangthuan99@users.noreply.github.com> Date: Tue, 31 Dec 2024 11:31:02 +0700 Subject: [PATCH 09/16] Feat/stream request python engine (#1829) * chore: add document * feat: update engine interface * chore: add document * feat: update engine interface * Feat: init python engine * Fix: conflict * feat: add python engine implementation * Fix: CI build window * Fix: CI build window * feat: support download python model from cortexso * feat: add inference interface * feat: integrate to cortex cpp * fix: remove pythone engine load engine option * Feat: init environment interface * feat: move virtual environment inside model * Update CMakeLists.txt * Update CMakeLists.txt * fix: CI build * fix: move log of python to cortex logs folder * fix: unitest for remote engine because change location of template renderer * fix: CI build windows * fix: CI build windows * feat: add depends model.yml for python engine * fix: CI build * stream response * update set permission api * Fix: comment * Feat: stream response * fix: run concurrent request with stream mode * Fix: remove unnecessary interface * Fix comment * Fix: comment review * fix comment * fix comment --------- Co-authored-by: James --- engine/controllers/server.cc | 57 +++++++++---- .../extensions/python-engine/python_engine.cc | 81 +++++++++++++++++-- .../extensions/python-engine/python_engine.h | 17 ++-- 3 files changed, 127 
insertions(+), 28 deletions(-) diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 961798d2c..83eaddb4e 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -129,6 +129,9 @@ void server::FineTuning( void server::Inference(const HttpRequestPtr& req, std::function&& callback) { + + auto json_body = req->getJsonObject(); + LOG_TRACE << "Start inference"; auto q = std::make_shared(); auto ir = inference_svc_->HandleInference(q, req->getJsonObject()); @@ -141,20 +144,34 @@ void server::Inference(const HttpRequestPtr& req, callback(resp); return; } + + bool is_stream = + (*json_body).get("stream", false).asBool() || + (*json_body).get("body", Json::Value()).get("stream", false).asBool(); + LOG_TRACE << "Wait to inference"; - auto [status, res] = q->wait_and_pop(); - LOG_DEBUG << "response: " << res.toStyledString(); - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode( - static_cast(status["status_code"].asInt())); - callback(resp); - LOG_TRACE << "Done inference"; + if (is_stream) { + auto model_id = (*json_body).get("model", "invalid_model").asString(); + auto engine_type = [this, &json_body]() -> std::string { + if (!inference_svc_->HasFieldInReq(json_body, "engine")) { + return kLlamaRepo; + } else { + return (*(json_body)).get("engine", kLlamaRepo).asString(); + } + }(); + ProcessStreamRes(callback, q, engine_type, model_id); + } else { + ProcessNonStreamRes(callback, *q); + LOG_TRACE << "Done inference"; + } } void server::RouteRequest( const HttpRequestPtr& req, std::function&& callback) { + auto json_body = req->getJsonObject(); + LOG_TRACE << "Start route request"; auto q = std::make_shared(); auto ir = inference_svc_->HandleRouteRequest(q, req->getJsonObject()); @@ -167,14 +184,26 @@ void server::RouteRequest( callback(resp); return; } + auto is_stream = + (*json_body).get("stream", false).asBool() || + (*json_body).get("body", Json::Value()).get("stream", false).asBool(); 
LOG_TRACE << "Wait to route request"; - auto [status, res] = q->wait_and_pop(); - LOG_DEBUG << "response: " << res.toStyledString(); - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode( - static_cast(status["status_code"].asInt())); - callback(resp); - LOG_TRACE << "Done route request"; + if (is_stream) { + + auto model_id = (*json_body).get("model", "invalid_model").asString(); + auto engine_type = [this, &json_body]() -> std::string { + if (!inference_svc_->HasFieldInReq(json_body, "engine")) { + return kLlamaRepo; + } else { + return (*(json_body)).get("engine", kLlamaRepo).asString(); + } + }(); + ProcessStreamRes(callback, q, engine_type, model_id); + } else { + ProcessNonStreamRes(callback, *q); + LOG_TRACE << "Done route request"; + } + } void server::LoadModel(const HttpRequestPtr& req, diff --git a/engine/extensions/python-engine/python_engine.cc b/engine/extensions/python-engine/python_engine.cc index ddf6784e8..9be369bcf 100644 --- a/engine/extensions/python-engine/python_engine.cc +++ b/engine/extensions/python-engine/python_engine.cc @@ -16,7 +16,8 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, return size * nmemb; } -PythonEngine::PythonEngine() {} +PythonEngine::PythonEngine() : q_(4 /*n_parallel*/, "python_engine") {} + PythonEngine::~PythonEngine() { curl_global_cleanup(); @@ -169,7 +170,7 @@ bool PythonEngine::TerminateModelProcess(const std::string& model) { } CurlResponse PythonEngine::MakeGetRequest(const std::string& model, const std::string& path) { - auto config = models_[model]; + auto const& config = models_[model]; std::string full_url = "http://localhost:" + config.port + path; CurlResponse response; @@ -184,7 +185,7 @@ CurlResponse PythonEngine::MakeGetRequest(const std::string& model, } CurlResponse PythonEngine::MakeDeleteRequest(const std::string& model, const std::string& path) { - auto config = models_[model]; + auto const& config = models_[model]; std::string full_url = 
"http://localhost:" + config.port + path; CurlResponse response; @@ -203,7 +204,7 @@ CurlResponse PythonEngine::MakeDeleteRequest(const std::string& model, CurlResponse PythonEngine::MakePostRequest(const std::string& model, const std::string& path, const std::string& body) { - auto config = models_[model]; + auto const& config = models_[model]; std::string full_url = "http://localhost:" + config.port + path; CurlResponse response; @@ -450,6 +451,63 @@ void PythonEngine::HandleChatCompletion( std::shared_ptr json_body, std::function&& callback) {} +CurlResponse PythonEngine::MakeStreamPostRequest( + const std::string& model, const std::string& path, const std::string& body, + const std::function& callback) { + auto const& config = models_[model]; + CURL* curl = curl_easy_init(); + CurlResponse response; + + if (!curl) { + response.error = true; + response.error_message = "Failed to initialize CURL"; + return response; + } + + std::string full_url = "http://localhost:" + config.port + path; + + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + headers = curl_slist_append(headers, "Accept: text/event-stream"); + headers = curl_slist_append(headers, "Cache-Control: no-cache"); + headers = curl_slist_append(headers, "Connection: keep-alive"); + + StreamContext context{ + std::make_shared>( + callback), + ""}; + + curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, StreamWriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &context); + curl_easy_setopt(curl, CURLOPT_TRANSFER_ENCODING, 1L); + + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + response.error = true; + response.error_message = curl_easy_strerror(res); + + Json::Value status; + status["is_done"] = true; + 
status["has_error"] = true; + status["is_stream"] = true; + status["status_code"] = 500; + + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + return response; +} + + void PythonEngine::HandleInference( std::shared_ptr json_body, std::function&& callback) { @@ -485,7 +543,8 @@ void PythonEngine::HandleInference( // Render with error handling try { - transformed_request = renderer_.Render(transform_request, *json_body); + transformed_request = renderer_.Render(transform_request, body); + } catch (const std::exception& e) { throw std::runtime_error("Template rendering error: " + std::string(e.what())); @@ -504,7 +563,17 @@ void PythonEngine::HandleInference( CurlResponse response; if (method == "post") { - response = MakePostRequest(model, path, transformed_request); + if (body.isMember("stream") && body["stream"].asBool()) { + q_.runTaskInQueue( + [this, model, path, transformed_request, cb = std::move(callback)] { + MakeStreamPostRequest(model, path, transformed_request, cb); + }); + + return; + } else { + response = MakePostRequest(model, path, transformed_request); + } + } else if (method == "get") { response = MakeGetRequest(model, path); } else if (method == "delete") { diff --git a/engine/extensions/python-engine/python_engine.h b/engine/extensions/python-engine/python_engine.h index 7b112f435..979ba1fd8 100644 --- a/engine/extensions/python-engine/python_engine.h +++ b/engine/extensions/python-engine/python_engine.h @@ -8,6 +8,8 @@ #include #include #include "config/model_config.h" +#include "trantor/utils/ConcurrentTaskQueue.h" + #include "cortex-common/EngineI.h" #include "extensions/template_renderer.h" #include "utils/file_logger.h" @@ -44,19 +46,12 @@ static size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, while ((pos = context->buffer.find('\n')) != std::string::npos) { std::string line = 
context->buffer.substr(0, pos); context->buffer = context->buffer.substr(pos + 1); + LOG_DEBUG << "line: "< async_file_logger_; std::unordered_map processMap; + trantor::ConcurrentTaskQueue q_; + // Helper functions CurlResponse MakePostRequest(const std::string& model, @@ -108,6 +105,10 @@ class PythonEngine : public EngineI { const std::string& path); CurlResponse MakeDeleteRequest(const std::string& model, const std::string& path); + CurlResponse MakeStreamPostRequest( + const std::string& model, const std::string& path, + const std::string& body, + const std::function& callback); // Process manager functions pid_t SpawnProcess(const std::string& model, From a77cd969130fc8f1af5dcab93c067fbffaa1cca3 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 <35255081+nguyenhoangthuan99@users.noreply.github.com> Date: Sat, 4 Jan 2025 12:39:55 +0700 Subject: [PATCH 10/16] fix: download recursive (#1838) * fix: download recursive * fix: handle model not loaded * fix: set permission for all venv folder * format code --- .../extensions/python-engine/python_engine.cc | 24 +++++++++++++++++++ engine/services/model_service.cc | 7 +----- engine/utils/curl_utils.cc | 8 +++++-- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/engine/extensions/python-engine/python_engine.cc b/engine/extensions/python-engine/python_engine.cc index 9be369bcf..a62f6526b 100644 --- a/engine/extensions/python-engine/python_engine.cc +++ b/engine/extensions/python-engine/python_engine.cc @@ -531,6 +531,18 @@ void PythonEngine::HandleInference( std::string model = (*json_body)["model"].asString(); Json::Value body = (*json_body)["body"]; + if (models_.find(model) == models_.end()) { + Json::Value error; + error["error"] = "Model '" + model + "' is not loaded!"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + // Transform 
Request std::string transformed_request; if (!transform_request.empty()) { @@ -699,6 +711,18 @@ void PythonEngine::HandleRouteRequest( std::string model = (*json_body)["model"].asString(); Json::Value body = (*json_body)["body"]; + if (models_.find(model) == models_.end()) { + Json::Value error; + error["error"] = "Model '" + model + "' is not loaded!"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + // Transform Request std::string transformed_request; if (!transform_request.empty()) { diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index c7925360b..e1d436058 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -582,13 +582,8 @@ ModelService::DownloadModelFromCortexsoAsync( pyvenv_cfg.close(); // Add executable permission to python -#ifdef _WIN32 - set_permission_utils::SetExecutePermissionsRecursive( - venv_path / std::filesystem::path("Scripts")); -#else set_permission_utils::SetExecutePermissionsRecursive( - venv_path / std::filesystem::path("bin")); -#endif + venv_path ); } else { CTL_ERR("Failed to extract venv.zip"); diff --git a/engine/utils/curl_utils.cc b/engine/utils/curl_utils.cc index be82b5cfa..d5945e8c8 100644 --- a/engine/utils/curl_utils.cc +++ b/engine/utils/curl_utils.cc @@ -271,8 +271,12 @@ cpp::result SimpleGetJsonRecursive( if (root.isArray()) { for (const auto& value : root) { if (value["type"].asString() == "directory") { - auto temp = SimpleGetJsonRecursive(url + "/" + value["path"].asString(), - timeout); + auto temp = SimpleGetJsonRecursive( + url + "/" + + std::filesystem::path(value["path"].asString()) + .filename() + .string(), + timeout); if (!temp.has_error()) { if (temp.value().isArray()) { for (const auto& item : temp.value()) { From c4b370f178d0234e9045ce4809e4d3fe3d907016 Mon Sep 17 00:00:00 2001 
From: nguyenhoangthuan99 <35255081+nguyenhoangthuan99@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:17:31 +0700 Subject: [PATCH 11/16] fix: cortex stop child process (#1841) * fix: download recursive * fix: handle model not loaded * fix: set permission for all venv folder * format code * fix: cortex stop will stop all running models * Fix: comment --- engine/controllers/process_manager.cc | 9 ++++++--- engine/controllers/process_manager.h | 7 +++++++ engine/extensions/python-engine/python_engine.cc | 8 +++++--- engine/main.cc | 2 +- engine/services/model_service.cc | 7 +------ 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/engine/controllers/process_manager.cc b/engine/controllers/process_manager.cc index 8373f08fe..9d1604754 100644 --- a/engine/controllers/process_manager.cc +++ b/engine/controllers/process_manager.cc @@ -1,13 +1,16 @@ #include "process_manager.h" -#include "utils/cortex_utils.h" - #include #include +#include "json/json.h" +#include "utils/cortex_utils.h" void ProcessManager::destroy( const HttpRequestPtr& req, std::function&& callback) { - + auto loaded_engines = engine_service_->GetSupportedEngineNames(); + for (const auto& engine : loaded_engines.value()) { + engine_service_->UnloadEngine(engine); + } app().quit(); Json::Value ret; ret["message"] = "Program is exitting, goodbye!"; diff --git a/engine/controllers/process_manager.h b/engine/controllers/process_manager.h index bded7b103..449e66d21 100644 --- a/engine/controllers/process_manager.h +++ b/engine/controllers/process_manager.h @@ -2,6 +2,7 @@ #include #include +#include "services/engine_service.h" using namespace drogon; @@ -13,4 +14,10 @@ class ProcessManager : public drogon::HttpController { void destroy(const HttpRequestPtr& req, std::function&& callback); + + ProcessManager(std::shared_ptr engine_service) + : engine_service_(engine_service) {} + + private: + std::shared_ptr engine_service_; }; diff --git a/engine/extensions/python-engine/python_engine.cc 
b/engine/extensions/python-engine/python_engine.cc index a62f6526b..f9557d70b 100644 --- a/engine/extensions/python-engine/python_engine.cc +++ b/engine/extensions/python-engine/python_engine.cc @@ -18,7 +18,6 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, PythonEngine::PythonEngine() : q_(4 /*n_parallel*/, "python_engine") {} - PythonEngine::~PythonEngine() { curl_global_cleanup(); } @@ -507,7 +506,6 @@ CurlResponse PythonEngine::MakeStreamPostRequest( return response; } - void PythonEngine::HandleInference( std::shared_ptr json_body, std::function&& callback) { @@ -943,7 +941,11 @@ void PythonEngine::Load(EngineLoadOption opts) { // Develop register model here on loading engine }; -void PythonEngine::Unload(EngineUnloadOption opts) {}; +void PythonEngine::Unload(EngineUnloadOption opts) { + for (const auto& pair : models_) { + TerminateModelProcess(pair.first); + } +}; // extern "C" { // EngineI* get_engine() { diff --git a/engine/main.cc b/engine/main.cc index 77f51c7fa..59ec49873 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -182,7 +182,7 @@ void RunServer(std::optional host, std::optional port, auto model_ctl = std::make_shared(db_service, model_service, engine_service, model_src_svc); auto event_ctl = std::make_shared(event_queue_ptr); - auto pm_ctl = std::make_shared(); + auto pm_ctl = std::make_shared(engine_service); auto hw_ctl = std::make_shared(engine_service, hw_service); auto server_ctl = std::make_shared(inference_svc, engine_service); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index e1d436058..74767a9b2 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -567,7 +567,6 @@ ModelService::DownloadModelFromCortexsoAsync( std::filesystem::path("python.exe")) .string() << std::endl; - #else pyvenv_cfg << "home = " << (venv_path / std::filesystem::path("bin/")).string() @@ -577,14 +576,10 @@ ModelService::DownloadModelFromCortexsoAsync( << (venv_path / 
std::filesystem::path("bin/python")).string() << std::endl; #endif - // Close the file pyvenv_cfg.close(); // Add executable permission to python - - set_permission_utils::SetExecutePermissionsRecursive( - venv_path ); - + set_permission_utils::SetExecutePermissionsRecursive(venv_path); } else { CTL_ERR("Failed to extract venv.zip"); }; From 7918935c696748a0aa184a91f878c021bd936043 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 <35255081+nguyenhoangthuan99@users.noreply.github.com> Date: Tue, 7 Jan 2025 08:43:26 +0700 Subject: [PATCH 12/16] Feat/python package ci (#1792) * feat: add ci for python package * feat: add ci for python package * feat: add ci for python package * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test package unix * test package mac arm * test package using miniconda * test package using miniconda * test package using miniconda * test package using miniconda * test package using miniconda * test package using miniconda * test package using miniconda * test package using miniconda * test package using miniconda * test package using miniconda * test package using miniconda * Add upload artifact linux * feat: add codesign for macos * Test CI window * Test CI window * Test CI window * Test CI window * test CI windows * test CI windows * test CI windows include hidden file * test CI macos include hidden file * test CI macos include hidden file * test CI macos include hidden file * chore: add package pipeline for 4 os * chore: add package pipeline for 4 os change compression level * chore: add package pipeline for 4 os optimize linux size * chore: add package pipeline for 4 os optimize linux size and fix windows * chore: add package pipeline fix windows * Feat python package codesign (#1780) * feat: add codesign for macos * feat: add codesign for macos * fix: notary python zipped folder --------- Co-authored-by: Hien To * Update python-package.yml * Update python-package.yml * test: package fish 
speech * test: package fish speech * test: rerun windows * feat: package env for ichigo-wrapper * feat: package env for ichigo-wrapper * feat: package env for ichigo-wrapper * feat: package env for whispervq * feat: package env for fish-speech * feat: package env for fish-speech mac * Fix: increase timeout for macos notarize * Update python-package.yml * Fix: upload venv to hf instead of github releas * Update: test run new CI for package python * Update: test run new CI for package python * Update: test run new CI for package python * Update: test run new CI for package python windows * Update: test run new CI for package python windows * Finished: venv package * feat: init CI for upload python script to huggingface * Finished CI for upload python scripts --------- Co-authored-by: Hien To Co-authored-by: hiento09 <136591877+hiento09@users.noreply.github.com> --- .github/workflows/python-script-package.yml | 72 +++++ .github/workflows/python-venv-package.yml | 275 ++++++++++++++++++++ 2 files changed, 347 insertions(+) create mode 100644 .github/workflows/python-script-package.yml create mode 100644 .github/workflows/python-venv-package.yml diff --git a/.github/workflows/python-script-package.yml b/.github/workflows/python-script-package.yml new file mode 100644 index 000000000..5ea65be9c --- /dev/null +++ b/.github/workflows/python-script-package.yml @@ -0,0 +1,72 @@ +name: Build and Package Python Code + +on: + workflow_dispatch: + inputs: + model_dir: + description: "Path to model directory in github repo" + required: true + repo_name: + description: "name of repo to be checked out" + required: true + branch_name: + description: "name of branch to be checked out" + required: true + default: main + hf_repo: + description: "name of huggingface repo to be pushed" + required: true + hf_prefix_branch: + description: "prefix of hf branch" + required: false + +env: + MODEL_DIR: ${{ inputs.model_dir }} + REPO_NAME: ${{ inputs.repo_name}} + BRANCH_NAME: ${{ 
inputs.branch_name }} + HF_REPO: ${{ inputs.hf_repo }} + HF_PREFIX_BRANCH: ${{ inputs.hf_prefix_branch }} + +jobs: + build-and-test: + runs-on: ${{ matrix.runs-on }} + timeout-minutes: 3600 + strategy: + fail-fast: false + matrix: + include: + - os: "linux" + name: "amd64" + runs-on: "ubuntu-20-04-cuda-12-0" + - os: "mac" + name: "amd64" + runs-on: "macos-selfhosted-12" + - os: "mac" + name: "arm64" + runs-on: "macos-selfhosted-12-arm64" + - os: "windows" + name: "amd64" + runs-on: "windows-cuda-12-0" + steps: + - name: Clone + id: checkout + uses: actions/checkout@v3 + with: + submodules: recursive + repository: ${{env.REPO_NAME}} + ref: ${{env.BRANCH_NAME}} + - name: use python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install hf-transfer huggingface_hub + + - name: Upload Artifact + run: | + huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN_WRITE }} --add-to-git-credential + cd ${{env.MODEL_DIR}} && huggingface-cli upload ${{env.HF_REPO}} . . 
--revision ${{env.HF_PREFIX_BRANCH}}-${{ matrix.os }}-${{ matrix.name }} + huggingface-cli logout \ No newline at end of file diff --git a/.github/workflows/python-venv-package.yml b/.github/workflows/python-venv-package.yml new file mode 100644 index 000000000..8bed4eb97 --- /dev/null +++ b/.github/workflows/python-venv-package.yml @@ -0,0 +1,275 @@ +name: Build and Package Python Virtual Environment + +on: + workflow_dispatch: + inputs: + model_dir: + description: "Path to model directory in github repo" + required: true + model_name: + description: "name of model to be release" + required: true + repo_name: + description: "name of repo to be checked out" + required: true + branch_name: + description: "name of branch to be checked out" + required: true + default: main + hf_repo: + description: "name of huggingface repo to be pushed" + required: true + hf_prefix_branch: + description: "prefix of hf branch" + required: false + + + +env: + MODEL_DIR: ${{ inputs.model_dir }} + MODEL_NAME: ${{ inputs.model_name }} + REPO_NAME: ${{ inputs.repo_name }} + BRANCH_NAME: ${{ inputs.branch_name }} + HF_REPO: ${{ inputs.hf_repo }} + HF_PREFIX_BRANCH: ${{ inputs.hf_prefix_branch }} + +jobs: + build-and-test: + runs-on: ${{ matrix.runs-on }} + timeout-minutes: 3600 + strategy: + fail-fast: false + matrix: + include: + - os: "linux" + name: "amd64" + runs-on: "ubuntu-20-04-cuda-12-0" + - os: "mac" + name: "amd64" + runs-on: "macos-selfhosted-12" + - os: "mac" + name: "arm64" + runs-on: "macos-selfhosted-12-arm64" + - os: "windows" + name: "amd64" + runs-on: "windows-cuda-12-0" + steps: + - name: Clone + id: checkout + uses: actions/checkout@v3 + with: + submodules: recursive + repository: ${{env.REPO_NAME}} + ref: ${{env.BRANCH_NAME}} + - uses: conda-incubator/setup-miniconda@v3 + if: runner.os != 'windows' + with: + auto-update-conda: true + python-version: 3.11 + - name: use python + if : runner.os == 'windows' + uses: actions/setup-python@v5 + with: + python-version: "3.11" + 
+ - name: Get Cer for code signing + if: runner.os == 'macOS' + run: base64 -d <<< "$CODE_SIGN_P12_BASE64" > /tmp/codesign.p12 + shell: bash + env: + CODE_SIGN_P12_BASE64: ${{ secrets.CODE_SIGN_P12_BASE64 }} + + - uses: apple-actions/import-codesign-certs@v2 + continue-on-error: true + if: runner.os == 'macOS' + with: + p12-file-base64: ${{ secrets.CODE_SIGN_P12_BASE64 }} + p12-password: ${{ secrets.CODE_SIGN_P12_PASSWORD }} + + - name: Get Cer for code signing + if: runner.os == 'macOS' + run: base64 -d <<< "$NOTARIZE_P8_BASE64" > /tmp/notary-key.p8 + shell: bash + env: + NOTARIZE_P8_BASE64: ${{ secrets.NOTARIZE_P8_BASE64 }} + + - name: Install dependencies Windows + if: runner.os == 'windows' + shell: pwsh + run: | + python3 -m pip install fastapi + python3 -m pip freeze | % { python3 -m pip uninstall -y $_ } + python3 -m pip install --upgrade pip + python3 -m pip install -I -r ${{env.MODEL_DIR}}/requirements.cuda.txt + python3 -m pip install python-dotenv + - name: Install dependencies Linux + if: runner.os == 'linux' + run: | + conda create -y -n ${{env.MODEL_NAME}} python=3.11 + source $HOME/miniconda3/bin/activate base + conda init + conda activate ${{env.MODEL_NAME}} + python -m pip install fastapi + python -m pip freeze | xargs python -m pip uninstall -y + python -m pip install --upgrade pip + python -m pip install -r ${{env.MODEL_DIR}}/requirements.cuda.txt + python -m pip install python-dotenv + - name: Install dependencies Mac + if: runner.os == 'macOS' + run: | + conda create -y -n ${{env.MODEL_NAME}} python=3.11 + source $HOME/miniconda3/bin/activate base + conda init + conda activate ${{env.MODEL_NAME}} + python -m pip install fastapi + python -m pip freeze | xargs python -m pip uninstall -y + python -m pip install --upgrade pip + python -m pip install -r ${{env.MODEL_DIR}}/requirements.txt + python -m pip install python-dotenv + + - name: prepare python package windows + if : runner.os == 'windows' + shell: pwsh + run: | + $pythonPath = where.exe 
python + echo "Python path (where.exe): $pythonPath" + $pythonFolder = Split-Path -Path "$pythonPath" -Parent + echo "PYTHON_FOLDER=$pythonFolder" >> $env:GITHUB_ENV + copy "$pythonFolder\python*.*" "$pythonFolder\Scripts\" + + - name: prepare python package macos + if : runner.os == 'macOs' + run: | + source $HOME/miniconda3/bin/activate base + conda init + conda activate ${{env.MODEL_NAME}} + PYTHON_PATH=$(which python) + echo $PYTHON_PATH + PYTHON_FOLDER=$(dirname $(dirname "$PYTHON_PATH")) + echo "PYTHON_FOLDER=$PYTHON_FOLDER" >> $GITHUB_ENV + echo "github end PYTHON_FOLDER: ${{env.PYTHON_FOLDER}}" + - name: prepare python package linux + if : runner.os == 'linux' + run: | + source $HOME/miniconda3/bin/activate base + conda init + conda activate ${{env.MODEL_NAME}} + PYTHON_PATH=$(which python) + echo $PYTHON_PATH + PYTHON_FOLDER=$(dirname $(dirname "$PYTHON_PATH")) + rm -rf $PYTHON_FOLDER/lib/python3.1 + echo "PYTHON_FOLDER=$PYTHON_FOLDER" >> $GITHUB_ENV + echo "github end PYTHON_FOLDER: ${{env.PYTHON_FOLDER}}" + + - name: create plist file + if: runner.os == 'macOS' + run: | + cat << EOF > /tmp/entitlements.plist + + + + + + com.apple.security.cs.allow-jit + + com.apple.security.cs.allow-unsigned-executable-memory + + + + com.apple.security.app-sandbox + + com.apple.security.network.client + + com.apple.security.network.server + + com.apple.security.device.audio-input + + com.apple.security.device.microphone + + com.apple.security.device.camera + + com.apple.security.files.user-selected.read-write + + com.apple.security.cs.disable-library-validation + + com.apple.security.cs.allow-dyld-environment-variables + + com.apple.security.cs.allow-executable-memory + + + + EOF + + - name: Notary macOS Binary + if: runner.os == 'macOS' + run: | + codesign --force --entitlements="/tmp/entitlements.plist" -s "${{ secrets.DEVELOPER_ID }}" --options=runtime ${{env.PYTHON_FOLDER}}/bin/python + codesign --force --entitlements="/tmp/entitlements.plist" -s "${{ 
secrets.DEVELOPER_ID }}" --options=runtime ${{env.PYTHON_FOLDER}}/bin/python3 + # Code sign all .so files and .dylib files + + find ${{env.PYTHON_FOLDER}} -type f \( -name "*.so" -o -name "*.dylib" \) -exec codesign --force --entitlements="/tmp/entitlements.plist" -s "${{ secrets.DEVELOPER_ID }}" --options=runtime {} \; + + curl -sSfL https://raw.githubusercontent.com/anchore/quill/main/install.sh | sudo sh -s -- -b /usr/local/bin + # Notarize the binary + quill notarize ${{env.PYTHON_FOLDER}}/bin/python + quill notarize ${{env.PYTHON_FOLDER}}/bin/python3 + find ${{env.PYTHON_FOLDER}} -type f \( -name "*.so" -o -name "*.dylib" \) -exec quill notarize {} \; + env: + QUILL_NOTARY_KEY_ID: ${{ secrets.NOTARY_KEY_ID }} + QUILL_NOTARY_ISSUER: ${{ secrets.NOTARY_ISSUER }} + QUILL_NOTARY_KEY: "/tmp/notary-key.p8" + + + - name: Upload Artifact MacOS + if : runner.os == 'macOS' + run: | + brew install zip + cd ${{env.PYTHON_FOLDER}} && zip -r venv.zip * + conda create -y -n hf-upload python=3.11 + source $HOME/miniconda3/bin/activate base + conda init + conda activate hf-upload + python -m pip install hf-transfer huggingface_hub + huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN_WRITE }} --add-to-git-credential + huggingface-cli upload ${{env.HF_REPO}} venv.zip --revision ${{env.HF_PREFIX_BRANCH}}-${{ matrix.os }}-${{ matrix.name }} + rm -rf venv.zip + huggingface-cli logout + + - name: Upload Artifact Linux + if : runner.os == 'linux' + run: | + sudo apt-get install -y zip + cd ${{env.PYTHON_FOLDER}} && zip -r venv.zip * + conda create -y -n hf-upload python=3.11 + source $HOME/miniconda3/bin/activate base + conda init + conda activate hf-upload + python -m pip install hf-transfer huggingface_hub + huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN_WRITE }} --add-to-git-credential + huggingface-cli upload ${{env.HF_REPO}} venv.zip --revision ${{env.HF_PREFIX_BRANCH}}-${{ matrix.os }}-${{ matrix.name }} + rm -rf venv.zip + huggingface-cli logout + + + - 
name: Upload Artifact Windows + if : runner.os == 'windows' + shell: pwsh + run: | + Compress-Archive -Path ${{env.PYTHON_FOLDER}}/* -DestinationPath venv.zip + python -m pip install hf-transfer huggingface_hub + huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN_WRITE }} --add-to-git-credential + huggingface-cli upload ${{env.HF_REPO}} venv.zip --revision ${{env.HF_PREFIX_BRANCH}}-${{ matrix.os }}-${{ matrix.name }} + rm venv.zip + huggingface-cli logout + + + - name: Post Upload windows + if : runner.os == 'windows' + run: | + rm ${{env.PYTHON_FOLDER}}/Scripts/python*.* + + - name: Remove Keychain + continue-on-error: true + if: always() && runner.os == 'macOS' + run: | + security delete-keychain signing_temp.keychain From 9b96b47a47d6b1de7881e7e4ef7ea99e8bcf4403 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 7 Jan 2025 09:07:11 +0700 Subject: [PATCH 13/16] fix: use after std::move --- engine/services/engine_service.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 39e6e7961..c6b107af3 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -763,8 +763,8 @@ cpp::result EngineService::LoadEngine( // register deps if (!(getenv("ENGINE_PATH"))) { std::vector paths{}; - paths.push_back(std::move(cuda_path)); - paths.push_back(std::move(engine_dir_path)); + paths.push_back(cuda_path); + paths.push_back(engine_dir_path); CTL_DBG("Registering dylib for " << ne << " with " << std::to_string(paths.size()) << " paths."); @@ -830,8 +830,8 @@ void EngineService::RegisterEngineLibPath() { // register deps std::vector paths{}; - paths.push_back(std::move(cuda_path)); - paths.push_back(std::move(engine_dir_path)); + paths.push_back(cuda_path); + paths.push_back(engine_dir_path); CTL_DBG("Registering dylib for " << ne << " with " << std::to_string(paths.size()) << " paths."); From 5825412b362198c3dd7083b7dc33717cc596dbba Mon 
Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 8 Jan 2025 09:19:08 +0700 Subject: [PATCH 14/16] fix: build linux arm (#1806) * fix: build linux arm * feat: cicd arm64 * fix: build linux arm * feat: cicd arm64 * fix: e2e test linux arm64 * fix: select linux arm64 * chore: e2e tests * fix: ci linux arm * fix: ci correct artifact name * fix: linux installer arm64 --------- Co-authored-by: Hien To Co-authored-by: sangjanai Co-authored-by: Service Account --- .github/workflows/beta-build.yml | 26 +++++- .github/workflows/cortex-cpp-quality-gate.yml | 36 +++++++++ .github/workflows/nightly-build.yml | 21 ++++- .github/workflows/stable-build.yml | 18 ++++- ...linux-x64.yml => template-build-linux.yml} | 80 ++++++++++++------- engine/Makefile | 3 +- engine/e2e-test/test_api_engine.py | 4 +- engine/e2e-test/test_api_engine_update.py | 2 +- engine/e2e-test/test_api_model.py | 1 + engine/e2e-test/test_cli_engine_install.py | 4 +- .../test_cli_engine_install_nightly.py | 1 + engine/e2e-test/test_cli_engine_uninstall.py | 2 +- engine/services/download_service.cc | 10 +-- engine/templates/linux/control | 2 +- engine/templates/linux/create_deb.sh | 4 +- engine/templates/linux/create_deb_local.sh | 3 +- engine/templates/linux/install.sh | 29 ++++--- .../components/test_engine_matcher_utils.cc | 13 +++ engine/utils/curl_utils.cc | 77 +++++++++++------- engine/utils/curl_utils.h | 7 +- engine/utils/engine_matcher_utils.h | 5 ++ 21 files changed, 255 insertions(+), 93 deletions(-) rename .github/workflows/{template-build-linux-x64.yml => template-build-linux.yml} (66%) diff --git a/.github/workflows/beta-build.yml b/.github/workflows/beta-build.yml index c5c09dcb5..bdc277231 100644 --- a/.github/workflows/beta-build.yml +++ b/.github/workflows/beta-build.yml @@ -67,7 +67,7 @@ jobs: cortex-llamacpp-version: ${{ needs.get-cortex-llamacpp-latest-version.outputs.cortex_llamacpp_latest_version }} build-linux-x64: - uses: ./.github/workflows/template-build-linux-x64.yml + uses: 
./.github/workflows/template-build-linux.yml secrets: inherit needs: [get-update-version, create-draft-release, get-cortex-llamacpp-latest-version] with: @@ -79,6 +79,22 @@ jobs: channel: beta upload_url: ${{ needs.create-draft-release.outputs.upload_url }} cortex-llamacpp-version: ${{ needs.get-cortex-llamacpp-latest-version.outputs.cortex_llamacpp_latest_version }} + arch: amd64 + + build-linux-arm64: + uses: ./.github/workflows/template-build-linux.yml + secrets: inherit + needs: [get-update-version, create-draft-release, get-cortex-llamacpp-latest-version] + with: + ref: ${{ github.ref }} + public_provider: github + new_version: ${{ needs.get-update-version.outputs.new_version }} + runs-on: ubuntu-2004-arm64 + cmake-flags: "-DCORTEX_VARIANT=beta -DCORTEX_CPP_VERSION='v${{ needs.get-update-version.outputs.new_version }}' -DCMAKE_TOOLCHAIN_FILE=/home/runner/actions-runner/_work/cortex.cpp/cortex.cpp/engine/vcpkg/scripts/buildsystems/vcpkg.cmake" + channel: beta + upload_url: ${{ needs.create-draft-release.outputs.upload_url }} + cortex-llamacpp-version: ${{ needs.get-cortex-llamacpp-latest-version.outputs.cortex_llamacpp_latest_version }} + arch: arm64 build-docker-x64: uses: ./.github/workflows/template-build-docker-x64.yml @@ -111,7 +127,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} noti-discord: - needs: [get-update-version, create-draft-release, build-macos, build-windows-x64, build-linux-x64, update_release] + needs: [get-update-version, create-draft-release, build-macos, build-windows-x64, build-linux-x64, build-linux-arm64, update_release] runs-on: ubuntu-latest permissions: contents: write @@ -136,9 +152,13 @@ jobs: - Network Installer: https://github.com/janhq/cortex.cpp/releases/download/v${{ env.VERSION }}/cortex-${{ env.VERSION }}-mac-universal-network-installer.pkg - Local Installer: https://github.com/janhq/cortex.cpp/releases/download/v${{ env.VERSION }}/cortex-${{ env.VERSION }}-mac-universal-local-installer.pkg - Binary: 
https://github.com/janhq/cortex.cpp/releases/download/v${{ env.VERSION }}/cortex-${{ env.VERSION }}-mac-universal.tar.gz - - Linux Deb: + - Linux amd64 Deb: - Network Installer: https://github.com/janhq/cortex.cpp/releases/download/v${{ env.VERSION }}/cortex-${{ env.VERSION }}-linux-amd64-network-installer.deb - Local Installer: https://github.com/janhq/cortex.cpp/releases/download/v${{ env.VERSION }}/cortex-${{ env.VERSION }}-linux-amd64-local-installer.deb - Binary: https://github.com/janhq/cortex.cpp/releases/download/v${{ env.VERSION }}/cortex-${{ env.VERSION }}-linux-amd64.tar.gz + - Linux arm64 Deb: + - Network Installer: https://github.com/janhq/cortex.cpp/releases/download/v${{ env.VERSION }}/cortex-${{ env.VERSION }}-linux-arm64-network-installer.deb + - Local Installer: https://github.com/janhq/cortex.cpp/releases/download/v${{ env.VERSION }}/cortex-${{ env.VERSION }}-linux-arm64-local-installer.deb + - Binary: https://github.com/janhq/cortex.cpp/releases/download/v${{ env.VERSION }}/cortex-${{ env.VERSION }}-linux-arm64.tar.gz - Docker: menloltd/cortex:beta-${{ env.VERSION }} - Github Release: https://github.com/janhq/cortex.cpp/releases/tag/v${{ env.VERSION }} \ No newline at end of file diff --git a/.github/workflows/cortex-cpp-quality-gate.yml b/.github/workflows/cortex-cpp-quality-gate.yml index 8a76e4669..fd98930a1 100644 --- a/.github/workflows/cortex-cpp-quality-gate.yml +++ b/.github/workflows/cortex-cpp-quality-gate.yml @@ -20,6 +20,12 @@ jobs: fail-fast: false matrix: include: + - os: "linux" + name: "arm64" + runs-on: "ubuntu-2004-arm64" + cmake-flags: "-DCORTEX_CPP_VERSION=${{github.event.pull_request.head.sha}} -DCMAKE_BUILD_TEST=ON -DCMAKE_TOOLCHAIN_FILE=vcpkg/scripts/buildsystems/vcpkg.cmake" + build-deps-cmake-flags: "" + ccache-dir: "" - os: "linux" name: "amd64" runs-on: "ubuntu-20-04-cuda-12-0" @@ -52,6 +58,7 @@ jobs: submodules: recursive - name: use python + continue-on-error: true uses: actions/setup-python@v5 with: python-version: 
"3.10" @@ -90,15 +97,44 @@ jobs: AWS_DEFAULT_REGION: "${{ secrets.MINIO_REGION }}" - name: Configure vcpkg + if: runner.os != 'Linux' + run: | + cd engine + make configure-vcpkg + + - name: Configure vcpkg linux amd64 + if: runner.os != 'Linux' + run: | + cd engine + make configure-vcpkg + + - name: Configure vcpkg linux arm64 + if: runner.os == 'Linux' run: | cd engine + # Set env if arch is arm64 + if [ "${{ matrix.name }}" == "arm64" ]; then + sudo apt install ninja-build pkg-config -y + export VCPKG_FORCE_SYSTEM_BINARIES=1 + fi make configure-vcpkg - name: Build + if: runner.os != 'Linux' run: | cd engine make build CMAKE_EXTRA_FLAGS="${{ matrix.cmake-flags }}" BUILD_DEPS_CMAKE_EXTRA_FLAGS="${{ matrix.build-deps-cmake-flags }}" + - name: Build + if: runner.os == 'Linux' + run: | + cd engine + if [ "${{ matrix.name }}" == "arm64" ]; then + export VCPKG_FORCE_SYSTEM_BINARIES=1 + fi + make build CMAKE_EXTRA_FLAGS="${{ matrix.cmake-flags }}" BUILD_DEPS_CMAKE_EXTRA_FLAGS="${{ matrix.build-deps-cmake-flags }}" + + - name: Run setup config run: | cd engine diff --git a/.github/workflows/nightly-build.yml b/.github/workflows/nightly-build.yml index 9a31ef5ff..1f076dc97 100644 --- a/.github/workflows/nightly-build.yml +++ b/.github/workflows/nightly-build.yml @@ -74,7 +74,7 @@ jobs: cortex-llamacpp-version: ${{ needs.get-cortex-llamacpp-latest-version.outputs.cortex_llamacpp_latest_version }} build-linux-x64: - uses: ./.github/workflows/template-build-linux-x64.yml + uses: ./.github/workflows/template-build-linux.yml secrets: inherit needs: [get-update-version, set-public-provider, get-cortex-llamacpp-latest-version] with: @@ -85,11 +85,26 @@ jobs: cmake-flags: "-DCORTEX_VARIANT=nightly -DCORTEX_CPP_VERSION='v${{ needs.get-update-version.outputs.new_version }}' -DCMAKE_TOOLCHAIN_FILE=/home/runner/actions-runner/_work/cortex.cpp/cortex.cpp/engine/vcpkg/scripts/buildsystems/vcpkg.cmake" channel: nightly cortex-llamacpp-version: ${{ 
needs.get-cortex-llamacpp-latest-version.outputs.cortex_llamacpp_latest_version }} + arch: amd64 + + build-linux-arm64: + uses: ./.github/workflows/template-build-linux.yml + secrets: inherit + needs: [get-update-version, set-public-provider, get-cortex-llamacpp-latest-version] + with: + ref: ${{ needs.set-public-provider.outputs.ref }} + public_provider: ${{ needs.set-public-provider.outputs.public_provider }} + new_version: ${{ needs.get-update-version.outputs.new_version }} + runs-on: ubuntu-2004-arm64 + cmake-flags: "-DCORTEX_VARIANT=nightly -DCORTEX_CPP_VERSION='v${{ needs.get-update-version.outputs.new_version }}' -DCMAKE_TOOLCHAIN_FILE=/home/runner/actions-runner/_work/cortex.cpp/cortex.cpp/engine/vcpkg/scripts/buildsystems/vcpkg.cmake" + channel: nightly + cortex-llamacpp-version: ${{ needs.get-cortex-llamacpp-latest-version.outputs.cortex_llamacpp_latest_version }} + arch: arm64 update-latest-version: runs-on: ubuntu-latest if: needs.set-public-provider.outputs.public_provider == 'aws-s3' - needs: [get-update-version, set-public-provider, build-linux-x64, build-macos, build-windows-x64, get-cortex-llamacpp-latest-version] + needs: [get-update-version, set-public-provider, build-linux-x64, build-linux-arm64, build-macos, build-windows-x64, get-cortex-llamacpp-latest-version] steps: - name: Update latest version id: update-latest-version @@ -100,9 +115,11 @@ jobs: aws s3 cp s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/mac-universal-cortex-nightly.tar.gz s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/latest/mac-arm64/cortex-nightly.tar.gz aws s3 cp s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/mac-universal-cortex-nightly.tar.gz s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/latest/mac-universal/cortex-nightly.tar.gz aws s3 cp s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/linux-amd64-cortex-nightly.tar.gz s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME 
}}/cortex/latest/linux-amd64/cortex-nightly.tar.gz + aws s3 cp s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/linux-arm64-cortex-nightly.tar.gz s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/latest/linux-arm64/cortex-nightly.tar.gz aws s3 cp s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/windows-amd64-cortex-nightly.tar.gz s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/latest/windows-amd64/cortex-nightly.tar.gz aws s3 cp s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/cortex-mac-universal-network-installer.pkg s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/latest/mac-universal/cortex-mac-universal-network-installer.pkg aws s3 cp s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/cortex-linux-amd64-network-installer.deb s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/latest/linux-amd64/cortex-linux-amd64-network-installer.deb + aws s3 cp s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/cortex-linux-arm64-network-installer.deb s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/latest/linux-arm64/cortex-linux-arm64-network-installer.deb aws s3 cp s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/cortex-windows-amd64-network-installer.exe s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/latest/windows-amd64/cortex-windows-amd64-network-installer.exe env: diff --git a/.github/workflows/stable-build.yml b/.github/workflows/stable-build.yml index 2b0523771..b05df983d 100644 --- a/.github/workflows/stable-build.yml +++ b/.github/workflows/stable-build.yml @@ -67,7 +67,7 @@ jobs: cortex-llamacpp-version: ${{ needs.get-cortex-llamacpp-latest-version.outputs.cortex_llamacpp_latest_version }} build-linux-x64: - uses: ./.github/workflows/template-build-linux-x64.yml + uses: ./.github/workflows/template-build-linux.yml secrets: inherit needs: [get-update-version, create-draft-release, get-cortex-llamacpp-latest-version] with: @@ -79,6 +79,22 @@ jobs: channel: stable 
upload_url: ${{ needs.create-draft-release.outputs.upload_url }} cortex-llamacpp-version: ${{ needs.get-cortex-llamacpp-latest-version.outputs.cortex_llamacpp_latest_version }} + arch: amd64 + + build-linux-arm64: + uses: ./.github/workflows/template-build-linux.yml + secrets: inherit + needs: [get-update-version, create-draft-release, get-cortex-llamacpp-latest-version] + with: + ref: ${{ github.ref }} + public_provider: github + new_version: ${{ needs.get-update-version.outputs.new_version }} + runs-on: ubuntu-2004-arm64 + cmake-flags: "-DCORTEX_VARIANT=prod -DCORTEX_CPP_VERSION='v${{ needs.get-update-version.outputs.new_version }}' -DCMAKE_TOOLCHAIN_FILE=/home/runner/actions-runner/_work/cortex.cpp/cortex.cpp/engine/vcpkg/scripts/buildsystems/vcpkg.cmake" + channel: stable + upload_url: ${{ needs.create-draft-release.outputs.upload_url }} + cortex-llamacpp-version: ${{ needs.get-cortex-llamacpp-latest-version.outputs.cortex_llamacpp_latest_version }} + arch: arm64 build-docker-x64: uses: ./.github/workflows/template-build-docker-x64.yml diff --git a/.github/workflows/template-build-linux-x64.yml b/.github/workflows/template-build-linux.yml similarity index 66% rename from .github/workflows/template-build-linux-x64.yml rename to .github/workflows/template-build-linux.yml index d1ca73844..02cc3a187 100644 --- a/.github/workflows/template-build-linux-x64.yml +++ b/.github/workflows/template-build-linux.yml @@ -1,4 +1,4 @@ -name: build-linux-x64 +name: build-linux on: workflow_call: inputs: @@ -49,6 +49,11 @@ on: type: string default: '0.0.0' description: 'The version of cortex-llamacpp to use for this job' + arch: + required: false + type: string + default: 'amd64' + description: 'The architecture to use for this job' secrets: DELTA_AWS_S3_BUCKET_NAME: required: false @@ -60,7 +65,7 @@ on: required: false jobs: - build-linux-x64: + build-linux: runs-on: ${{ inputs.runs-on }} permissions: contents: write @@ -72,6 +77,7 @@ jobs: submodules: 'recursive' - name: use 
python 3.9 + continue-on-error: true uses: actions/setup-python@v4 with: python-version: '3.9' @@ -124,14 +130,24 @@ jobs: - name: Configure vcpkg run: | cd engine + # Set env if arch is arm64 + if [ "${{ inputs.arch }}" == "arm64" ]; then + sudo apt install ninja-build pkg-config -y + export VCPKG_FORCE_SYSTEM_BINARIES=1 + fi make configure-vcpkg - name: Build run: | cd engine + # Set env if arch is arm64 + if [ "${{ inputs.arch }}" == "arm64" ]; then + export VCPKG_FORCE_SYSTEM_BINARIES=1 + fi make build CMAKE_EXTRA_FLAGS="${{ inputs.cmake-flags }}" BUILD_DEPS_CMAKE_EXTRA_FLAGS="${{ inputs.build-deps-cmake-flags }}" - name: Install Python + continue-on-error: true uses: actions/setup-python@v4 with: python-version: '3.10' @@ -145,28 +161,32 @@ jobs: shell: bash run: | cd engine - make build-installer PACKAGE_NAME="${{ steps.set-output-params.outputs.package_name }}" SOURCE_BINARY_PATH="../../cortex/${{ steps.set-output-params.outputs.destination_binary_name }}" SOURCE_BINARY_SERVER_PATH="../../cortex/${{ steps.set-output-params.outputs.destination_binary_server_name }}" VERSION=${{ inputs.new_version }} DESTINATION_BINARY_NAME="${{ steps.set-output-params.outputs.destination_binary_name }}" DESTINATION_BINARY_SERVER_NAME="${{ steps.set-output-params.outputs.destination_binary_server_name }}" DATA_FOLDER_NAME="${{ steps.set-output-params.outputs.data_folder_name }}" CONFIGURATION_FILE_NAME="${{ steps.set-output-params.outputs.configuration_file_name }}" UNINSTALLER_FILE_NAME="${{ steps.set-output-params.outputs.uninstaller_file_name }}" + make build-installer PACKAGE_NAME="${{ steps.set-output-params.outputs.package_name }}" SOURCE_BINARY_PATH="../../cortex/${{ steps.set-output-params.outputs.destination_binary_name }}" SOURCE_BINARY_SERVER_PATH="../../cortex/${{ steps.set-output-params.outputs.destination_binary_server_name }}" VERSION=${{ inputs.new_version }} DESTINATION_BINARY_NAME="${{ steps.set-output-params.outputs.destination_binary_name }}" 
DESTINATION_BINARY_SERVER_NAME="${{ steps.set-output-params.outputs.destination_binary_server_name }}" DATA_FOLDER_NAME="${{ steps.set-output-params.outputs.data_folder_name }}" CONFIGURATION_FILE_NAME="${{ steps.set-output-params.outputs.configuration_file_name }}" UNINSTALLER_FILE_NAME="${{ steps.set-output-params.outputs.uninstaller_file_name }}" ARCH="${{ inputs.arch }}" mv ${{ steps.set-output-params.outputs.package_name }}.deb ${{ steps.set-output-params.outputs.package_name }}-network.deb - name: Build local Installers run: | mkdir -p engine/templates/linux/dependencies cd engine/templates/linux/dependencies - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx-cuda-11-7.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx-cuda-12-0.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx2-cuda-11-7.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx2-cuda-12-0.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx2.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx512-cuda-11-7.tar.gz - wget 
https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx512-cuda-12-0.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx512.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-noavx-cuda-11-7.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-noavx-cuda-12-0.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-noavx.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-vulkan.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cuda-11-7-linux-amd64.tar.gz - wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cuda-12-0-linux-amd64.tar.gz + if [ "${{ inputs.arch }}" == "amd64" ]; then + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx-cuda-11-7.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx-cuda-12-0.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version 
}}-linux-amd64-avx.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx2-cuda-11-7.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx2-cuda-12-0.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx2.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx512-cuda-11-7.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx512-cuda-12-0.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-avx512.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-noavx-cuda-11-7.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-noavx-cuda-12-0.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-noavx.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-amd64-vulkan.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ 
inputs.cortex-llamacpp-version }}/cuda-11-7-linux-amd64.tar.gz + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cuda-12-0-linux-amd64.tar.gz + else + wget https://github.com/janhq/cortex.llamacpp/releases/download/v${{ inputs.cortex-llamacpp-version }}/cortex.llamacpp-${{ inputs.cortex-llamacpp-version }}-linux-arm64.tar.gz + fi cd .. # Remove network package @@ -174,7 +194,7 @@ jobs: rm -rf ${{ steps.set-output-params.outputs.package_name }} rm ${{ steps.set-output-params.outputs.package_name }}.deb chmod +x create_deb_local.sh - ./create_deb_local.sh ${{ steps.set-output-params.outputs.package_name }} ${{ inputs.new_version }} ../../cortex/${{ steps.set-output-params.outputs.destination_binary_name }} ../../cortex/${{ steps.set-output-params.outputs.destination_binary_server_name }} ${{ steps.set-output-params.outputs.destination_binary_name }} ${{ steps.set-output-params.outputs.destination_binary_server_name }} ${{ steps.set-output-params.outputs.data_folder_name }} ${{ steps.set-output-params.outputs.configuration_file_name }}; + ./create_deb_local.sh ${{ steps.set-output-params.outputs.package_name }} ${{ inputs.new_version }} ../../cortex/${{ steps.set-output-params.outputs.destination_binary_name }} ../../cortex/${{ steps.set-output-params.outputs.destination_binary_server_name }} ${{ steps.set-output-params.outputs.destination_binary_name }} ${{ steps.set-output-params.outputs.destination_binary_server_name }} ${{ steps.set-output-params.outputs.data_folder_name }} ${{ steps.set-output-params.outputs.configuration_file_name }} ${{ inputs.arch }}; cp ${{ steps.set-output-params.outputs.package_name }}.deb ../../${{ steps.set-output-params.outputs.package_name }}-local.deb - name: Package @@ -185,30 +205,30 @@ jobs: - name: Upload Artifact uses: actions/upload-artifact@v4 with: - name: cortex-${{ inputs.new_version }}-linux-amd64 + name: cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }} path: 
./engine/cortex - name: Upload Artifact uses: actions/upload-artifact@v4 with: - name: cortex-${{ inputs.new_version }}-linux-amd64-network-installer + name: cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }}-network-installer path: ./engine/${{ steps.set-output-params.outputs.package_name }}-network.deb - name: Upload Artifact uses: actions/upload-artifact@v4 with: - name: cortex-${{ inputs.new_version }}-linux-amd64-local-installer + name: cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }}-local-installer path: ./engine/${{ steps.set-output-params.outputs.package_name }}-local.deb - name: upload to aws s3 if public provider is aws if: inputs.public_provider == 'aws-s3' run: | - aws s3 cp ./engine/cortex.tar.gz s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/linux-amd64-cortex-nightly.tar.gz - aws s3 cp ./engine/${{ steps.set-output-params.outputs.package_name }}-network.deb s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/cortex-linux-amd64-network-installer.deb + aws s3 cp ./engine/cortex.tar.gz s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/linux-${{ inputs.arch }}-cortex-nightly.tar.gz + aws s3 cp ./engine/${{ steps.set-output-params.outputs.package_name }}-network.deb s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/temp-latest/cortex-linux-${{ inputs.arch }}-network-installer.deb - aws s3 cp ./engine/cortex.tar.gz s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/v${{ inputs.new_version }}/linux-amd64/cortex-nightly.tar.gz - aws s3 cp ./engine/${{ steps.set-output-params.outputs.package_name }}-network.deb s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/v${{ inputs.new_version }}/linux-amd64/cortex-${{ inputs.new_version }}-linux-amd64-network-installer.deb - aws s3 cp ./engine/${{ steps.set-output-params.outputs.package_name }}-local.deb s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/v${{ inputs.new_version }}/linux-amd64/cortex-${{ inputs.new_version 
}}-linux-amd64-local-installer.deb
+          aws s3 cp ./engine/cortex.tar.gz s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/v${{ inputs.new_version }}/linux-${{ inputs.arch }}/cortex-nightly.tar.gz
+          aws s3 cp ./engine/${{ steps.set-output-params.outputs.package_name }}-network.deb s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/v${{ inputs.new_version }}/linux-${{ inputs.arch }}/cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }}-network-installer.deb
+          aws s3 cp ./engine/${{ steps.set-output-params.outputs.package_name }}-local.deb s3://${{ secrets.DELTA_AWS_S3_BUCKET_NAME }}/cortex/v${{ inputs.new_version }}/linux-${{ inputs.arch }}/cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }}-local-installer.deb
         env:
           AWS_ACCESS_KEY_ID: ${{ secrets.DELTA_AWS_ACCESS_KEY_ID }}
           AWS_SECRET_ACCESS_KEY: ${{ secrets.DELTA_AWS_SECRET_ACCESS_KEY }}
@@ -223,7 +243,7 @@ jobs:
         with:
           upload_url: ${{ inputs.upload_url }}
           asset_path: ./engine/cortex.tar.gz
-          asset_name: cortex-${{ inputs.new_version }}-linux-amd64.tar.gz
+          asset_name: cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }}.tar.gz
           asset_content_type: application/zip

       - name: Upload release assert if public provider is github
@@ -234,7 +254,7 @@ jobs:
         with:
           upload_url: ${{ inputs.upload_url }}
           asset_path: ./engine/${{ steps.set-output-params.outputs.package_name }}-network.deb
-          asset_name: cortex-${{ inputs.new_version }}-linux-amd64-network-installer.deb
+          asset_name: cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }}-network-installer.deb
           asset_content_type: application/octet-stream

       - name: Upload release assert if public provider is github
@@ -245,5 +265,5 @@ jobs:
         with:
           upload_url: ${{ inputs.upload_url }}
           asset_path: ./engine/${{ steps.set-output-params.outputs.package_name }}-local.deb
-          asset_name: cortex-${{ inputs.new_version }}-linux-amd64-local-installer.deb
+          asset_name: cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }}-local-installer.deb
           asset_content_type:
application/octet-stream \ No newline at end of file diff --git a/engine/Makefile b/engine/Makefile index 8f27eebcc..a6a4b5a79 100644 --- a/engine/Makefile +++ b/engine/Makefile @@ -24,6 +24,7 @@ DESTINATION_BINARY_SERVER_NAME ?= cortex-server DATA_FOLDER_NAME ?= .cortex CONFIGURATION_FILE_NAME ?= .cortexrc UNINSTALLER_FILE_NAME ?= cortex-uninstall.sh +ARCH ?= amd64 # Default target, does nothing all: @@ -120,7 +121,7 @@ else ifeq ($(shell uname -s),Linux) @echo "Building installer for linux"; \ cd templates/linux; \ chmod +x create_deb.sh; \ - ./create_deb.sh $(PACKAGE_NAME) $(VERSION) $(SOURCE_BINARY_PATH) $(SOURCE_BINARY_SERVER_PATH) $(DESTINATION_BINARY_NAME) $(DESTINATION_BINARY_SERVER_NAME) $(DATA_FOLDER_NAME) $(CONFIGURATION_FILE_NAME); \ + ./create_deb.sh $(PACKAGE_NAME) $(VERSION) $(SOURCE_BINARY_PATH) $(SOURCE_BINARY_SERVER_PATH) $(DESTINATION_BINARY_NAME) $(DESTINATION_BINARY_SERVER_NAME) $(DATA_FOLDER_NAME) $(CONFIGURATION_FILE_NAME) $(ARCH); \ cp $(PACKAGE_NAME).deb ../../ else @echo "Building installer for Macos"; \ diff --git a/engine/e2e-test/test_api_engine.py b/engine/e2e-test/test_api_engine.py index 57b47b879..e652e4495 100644 --- a/engine/e2e-test/test_api_engine.py +++ b/engine/e2e-test/test_api_engine.py @@ -28,14 +28,14 @@ def test_engines_get_llamacpp_should_be_successful(self): # engines install def test_engines_install_llamacpp_specific_version_and_variant(self): - data = {"version": "v0.1.35-27.10.24", "variant": "linux-amd64-avx-cuda-11-7"} + data = {"version": "v0.1.40-b4354", "variant": "linux-amd64-avx-cuda-11-7"} response = requests.post( "http://localhost:3928/v1/engines/llama-cpp/install", json=data ) assert response.status_code == 200 def test_engines_install_llamacpp_specific_version_and_null_variant(self): - data = {"version": "v0.1.35-27.10.24"} + data = {"version": "v0.1.40-b4354"} response = requests.post( "http://localhost:3928/v1/engines/llama-cpp/install", json=data ) diff --git a/engine/e2e-test/test_api_engine_update.py 
b/engine/e2e-test/test_api_engine_update.py index 23939f038..be8685dba 100644 --- a/engine/e2e-test/test_api_engine_update.py +++ b/engine/e2e-test/test_api_engine_update.py @@ -25,7 +25,7 @@ def setup_and_teardown(self): @pytest.mark.asyncio async def test_engines_update_should_be_successfully(self): - requests.post("http://localhost:3928/v1/engines/llama-cpp?version=0.1.34") + requests.post("http://localhost:3928/v1/engines/llama-cpp?version=0.1.43") response = requests.post("http://localhost:3928/v1/engines/llama-cpp/update") assert response.status_code == 200 diff --git a/engine/e2e-test/test_api_model.py b/engine/e2e-test/test_api_model.py index 8f2e4b07a..d75aa6831 100644 --- a/engine/e2e-test/test_api_model.py +++ b/engine/e2e-test/test_api_model.py @@ -85,6 +85,7 @@ async def test_model_pull_with_direct_url_should_have_desired_name(self): ], ) + @pytest.mark.asyncio async def test_models_start_stop_should_be_successful(self): print("Install engine") response = requests.post("http://localhost:3928/v1/engines/llama-cpp/install") diff --git a/engine/e2e-test/test_cli_engine_install.py b/engine/e2e-test/test_cli_engine_install.py index a998f3183..dbbc16e8a 100644 --- a/engine/e2e-test/test_cli_engine_install.py +++ b/engine/e2e-test/test_cli_engine_install.py @@ -49,7 +49,7 @@ def test_engines_install_onnx_on_tensorrt_should_be_failed(self): @pytest.mark.skipif(platform.system() == "Windows", reason="Progress bar log issue on Windows") def test_engines_install_pre_release_llamacpp(self): - engine_version = "v0.1.29" + engine_version = "v0.1.43" exit_code, output, error = run( "Install Engine", ["engines", "install", "llama-cpp", "-v", engine_version], @@ -69,7 +69,7 @@ def test_engines_install_pre_release_llamacpp(self): assert is_engine_version_exist, f"Engine version {engine_version} is not found" assert exit_code == 0, f"Install engine failed with error: {error}" - @pytest.mark.skipif(platform.system() == "Windows", reason="Progress bar log issue on 
Windows") + @pytest.mark.skipif(platform.system() == "Windows" or platform.system() == "Linux", reason="Progress bar log issue on Windows") def test_engines_should_fallback_to_download_llamacpp_engine_if_not_exists(self): exit_code, output, error = run( "Install Engine", diff --git a/engine/e2e-test/test_cli_engine_install_nightly.py b/engine/e2e-test/test_cli_engine_install_nightly.py index 8c66c284c..bbb56ac9b 100644 --- a/engine/e2e-test/test_cli_engine_install_nightly.py +++ b/engine/e2e-test/test_cli_engine_install_nightly.py @@ -47,6 +47,7 @@ def test_engines_install_onnx_on_tensorrt_should_be_failed(self): assert "is not supported on" in output, "Should display error message" assert exit_code == 0, f"Install engine failed with error: {error}" + @pytest.mark.skipif(platform.system() == "Linux", reason="Wait for linux arm ready") def test_engines_should_fallback_to_download_llamacpp_engine_if_not_exists(self): exit_code, output, error = run( "Install Engine", diff --git a/engine/e2e-test/test_cli_engine_uninstall.py b/engine/e2e-test/test_cli_engine_uninstall.py index fcc5f5c73..6b640d45d 100644 --- a/engine/e2e-test/test_cli_engine_uninstall.py +++ b/engine/e2e-test/test_cli_engine_uninstall.py @@ -24,7 +24,7 @@ def setup_and_teardown(self): @pytest.mark.asyncio async def test_engines_uninstall_llamacpp_should_be_successfully(self): - requests.post("http://127.0.0.1:3928/v1/engines/llama-cpp/install") + response = requests.post("http://localhost:3928/v1/engines/llama-cpp/install") await wait_for_websocket_download_success_event(timeout=None) exit_code, output, error = run( "Uninstall engine", ["engines", "uninstall", "llama-cpp"] diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index d855c8f61..a38dbe70e 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -140,10 +140,10 @@ cpp::result DownloadService::GetFileSize( curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); auto headers = 
curl_utils::GetHeaders(url); - if (headers.has_value()) { + if (headers) { curl_slist* curl_headers = nullptr; - for (const auto& [key, value] : headers.value()) { + for (const auto& [key, value] : headers->m) { auto header = key + ": " + value; curl_headers = curl_slist_append(curl_headers, header.c_str()); } @@ -227,10 +227,10 @@ cpp::result DownloadService::Download( curl_easy_setopt(curl, CURLOPT_URL, download_item.downloadUrl.c_str()); auto headers = curl_utils::GetHeaders(download_item.downloadUrl); - if (headers.has_value()) { + if (headers) { curl_slist* curl_headers = nullptr; - for (const auto& [key, value] : headers.value()) { + for (const auto& [key, value] : headers->m) { auto header = key + ": " + value; curl_headers = curl_slist_append(curl_headers, header.c_str()); } @@ -469,7 +469,7 @@ void DownloadService::SetUpCurlHandle(CURL* handle, const DownloadItem& item, auto headers = curl_utils::GetHeaders(item.downloadUrl); if (headers) { curl_slist* curl_headers = nullptr; - for (const auto& [key, value] : headers.value()) { + for (const auto& [key, value] : headers->m) { curl_headers = curl_slist_append(curl_headers, (key + ": " + value).c_str()); } diff --git a/engine/templates/linux/control b/engine/templates/linux/control index e877fe5ab..7c129a690 100644 --- a/engine/templates/linux/control +++ b/engine/templates/linux/control @@ -2,7 +2,7 @@ Package: $PACKAGE_NAME Version: $VERSION Section: base Priority: optional -Architecture: amd64 +Architecture: $ARCH Depends: openmpi-bin,libopenmpi-dev Maintainer: Homebrew Computer Pte Ltd Description: Cortex diff --git a/engine/templates/linux/create_deb.sh b/engine/templates/linux/create_deb.sh index 29492bdd8..f9247972f 100644 --- a/engine/templates/linux/create_deb.sh +++ b/engine/templates/linux/create_deb.sh @@ -6,6 +6,7 @@ DESTINATION_BINARY_NAME=$5 DESTINATION_BINARY_SERVER_NAME=$6 DATA_FOLDER_NAME=$7 CONFIGURATION_FILE_NAME=$8 +ARCH=$9 mkdir -p $PACKAGE_NAME/DEBIAN @@ -31,8 +32,9 @@ chmod 755 
$PACKAGE_NAME/DEBIAN/postinst chmod 755 $PACKAGE_NAME/DEBIAN/postrm chmod 755 $PACKAGE_NAME/DEBIAN/prerm -export PACKAGE_NAME VERSION +export PACKAGE_NAME VERSION ARCH envsubst < control > $PACKAGE_NAME/DEBIAN/control +sed -i "s/ARCH/$ARCH/" $PACKAGE_NAME/DEBIAN/control dpkg-deb --build $PACKAGE_NAME $PACKAGE_NAME.deb diff --git a/engine/templates/linux/create_deb_local.sh b/engine/templates/linux/create_deb_local.sh index 6b54dc19d..0ab0c79c5 100644 --- a/engine/templates/linux/create_deb_local.sh +++ b/engine/templates/linux/create_deb_local.sh @@ -6,6 +6,7 @@ DESTINATION_BINARY_NAME=$5 DESTINATION_BINARY_SERVER_NAME=$6 DATA_FOLDER_NAME=$7 CONFIGURATION_FILE_NAME=$8 +ARCH=$9 mkdir -p $PACKAGE_NAME/DEBIAN @@ -34,7 +35,7 @@ chmod 755 $PACKAGE_NAME/DEBIAN/postinst chmod 755 $PACKAGE_NAME/DEBIAN/postrm chmod 755 $PACKAGE_NAME/DEBIAN/prerm -export PACKAGE_NAME VERSION +export PACKAGE_NAME VERSION ARCH envsubst < control > $PACKAGE_NAME/DEBIAN/control diff --git a/engine/templates/linux/install.sh b/engine/templates/linux/install.sh index e11b879c6..ade0b134a 100644 --- a/engine/templates/linux/install.sh +++ b/engine/templates/linux/install.sh @@ -6,6 +6,17 @@ if [ "$(id -u)" != "0" ]; then exit 1 fi +# Determine architecture +ARCH=$(uname -m) +if [ "$ARCH" = "x86_64" ]; then + ARCH="amd64" +elif [ "$ARCH" = "aarch64" ]; then + ARCH="arm64" +else + echo "Unsupported architecture: $ARCH" + exit 1 +fi + # Determine the home directory based on the user USER_TO_RUN_AS=${SUDO_USER:-$(whoami)} if [ "$USER_TO_RUN_AS" = "root" ]; then @@ -142,19 +153,19 @@ install_cortex() { case $channel in stable) - url_binary="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-amd64.tar.gz" - url_deb_local="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-amd64-local-installer.deb" - 
url_deb_network="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-amd64-network-installer.deb" + url_binary="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-${ARCH}.tar.gz" + url_deb_local="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-${ARCH}-local-installer.deb" + url_deb_network="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-${ARCH}-network-installer.deb" ;; beta) - url_binary="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-amd64.tar.gz" - url_deb_local="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-amd64-local-installer.deb" - url_deb_network="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-amd64-network-installer.deb" + url_binary="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-${ARCH}.tar.gz" + url_deb_local="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-${ARCH}-local-installer.deb" + url_deb_network="https://github.com/janhq/cortex.cpp/releases/download/v${version}/cortex-${version}-linux-${ARCH}-network-installer.deb" ;; nightly) - url_binary="https://delta.jan.ai/cortex/v${version}/linux-amd64/cortex-nightly.tar.gz" - url_deb_local="https://delta.jan.ai/cortex/v${version}/linux-amd64/cortex-${version}-linux-amd64-local-installer.deb" - url_deb_network="https://delta.jan.ai/cortex/v${version}/linux-amd64/cortex-${version}-linux-amd64-network-installer.deb" + url_binary="https://delta.jan.ai/cortex/v${version}/linux-${ARCH}/cortex-nightly.tar.gz" + url_deb_local="https://delta.jan.ai/cortex/v${version}/linux-${ARCH}/cortex-${version}-linux-${ARCH}-local-installer.deb" + 
url_deb_network="https://delta.jan.ai/cortex/v${version}/linux-${ARCH}/cortex-${version}-linux-${ARCH}-network-installer.deb" ;; esac diff --git a/engine/test/components/test_engine_matcher_utils.cc b/engine/test/components/test_engine_matcher_utils.cc index 7da4e3cd1..1d1ed47a8 100644 --- a/engine/test/components/test_engine_matcher_utils.cc +++ b/engine/test/components/test_engine_matcher_utils.cc @@ -19,6 +19,7 @@ class EngineMatcherUtilsTestSuite : public ::testing::Test { "cortex.llamacpp-0.1.25-25.08.24-linux-amd64-noavx-cuda-12-0.tar.gz", "cortex.llamacpp-0.1.25-25.08.24-linux-amd64-noavx.tar.gz", "cortex.llamacpp-0.1.25-25.08.24-linux-amd64-vulkan.tar.gz", + "cortex.llamacpp-0.1.43-linux-arm64.tar.gz", "cortex.llamacpp-0.1.25-25.08.24-mac-amd64.tar.gz", "cortex.llamacpp-0.1.25-25.08.24-mac-arm64.tar.gz", "cortex.llamacpp-0.1.25-25.08.24-windows-amd64-avx-cuda-11-7.tar.gz", @@ -134,6 +135,18 @@ TEST_F(EngineMatcherUtilsTestSuite, TestValidate) { EXPECT_EQ(variant, "cortex.llamacpp-0.1.25-25.08.24-windows-amd64-avx2.tar.gz"); } + + { + auto os{"linux"}; + auto cpu_arch{"arm64"}; + auto suitable_avx{""}; + auto cuda_version{""}; + + auto variant = engine_matcher_utils::Validate( + cortex_llamacpp_variants, os, cpu_arch, suitable_avx, cuda_version); + + EXPECT_EQ(variant, "cortex.llamacpp-0.1.43-linux-arm64.tar.gz"); + } } TEST_F(EngineMatcherUtilsTestSuite, TestGetVersionAndArch) { diff --git a/engine/utils/curl_utils.cc b/engine/utils/curl_utils.cc index d5945e8c8..2481658ad 100644 --- a/engine/utils/curl_utils.cc +++ b/engine/utils/curl_utils.cc @@ -9,12 +9,24 @@ namespace curl_utils { namespace { -size_t WriteCallback(void* contents, size_t size, size_t nmemb, - std::string* output) { - size_t totalSize = size * nmemb; - output->append((char*)contents, totalSize); - return totalSize; -} +class CurlResponse { + public: + static size_t WriteCallback(char* buffer, size_t size, size_t nitems, + void* userdata) { + auto* response = static_cast(userdata); + 
return response->Append(buffer, size * nitems); + } + + size_t Append(const char* buffer, size_t size) { + data_.append(buffer, size); + return size; + } + + const std::string& GetData() const { return data_; } + + private: + std::string data_; +}; void SetUpProxy(CURL* handle, const std::string& url) { auto config = file_manager_utils::GetCortexConfig(); @@ -59,19 +71,18 @@ void SetUpProxy(CURL* handle, const std::string& url) { } } // namespace -std::optional> GetHeaders( - const std::string& url) { +std::shared_ptr
GetHeaders(const std::string& url) {
   auto url_obj = url_parser::FromUrlString(url);
   if (url_obj.has_error()) {
-    return std::nullopt;
+    return nullptr;
   }
 
   if (url_obj->host == kHuggingFaceHost) {
-    std::unordered_map<std::string, std::string> headers{};
-    headers["Content-Type"] = "application/json";
+    auto headers = std::make_shared<Header>
(); + headers->m["Content-Type"] = "application/json"; auto const& token = file_manager_utils::GetCortexConfig().huggingFaceToken; if (!token.empty()) { - headers["Authorization"] = "Bearer " + token; + headers->m["Authorization"] = "Bearer " + token; // for debug purpose auto min_token_size = 6; @@ -87,15 +98,15 @@ std::optional> GetHeaders( } if (url_obj->host == kGitHubHost) { - std::unordered_map headers{}; - headers["Accept"] = "application/vnd.github.v3+json"; + auto headers = std::make_shared
(); + headers->m["Accept"] = "application/vnd.github.v3+json"; // github API requires user-agent https://docs.github.com/en/rest/using-the-rest-api/getting-started-with-the-rest-api?apiVersion=2022-11-28#user-agent auto user_agent = file_manager_utils::GetCortexConfig().gitHubUserAgent; auto gh_token = file_manager_utils::GetCortexConfig().gitHubToken; - headers["User-Agent"] = + headers->m["User-Agent"] = user_agent.empty() ? kDefaultGHUserAgent : user_agent; if (!gh_token.empty()) { - headers["Authorization"] = "Bearer " + gh_token; + headers->m["Authorization"] = "Bearer " + gh_token; // for debug purpose auto min_token_size = 6; @@ -109,7 +120,7 @@ std::optional> GetHeaders( return headers; } - return std::nullopt; + return nullptr; } cpp::result SimpleGet(const std::string& url, @@ -122,8 +133,8 @@ cpp::result SimpleGet(const std::string& url, auto headers = GetHeaders(url); curl_slist* curl_headers = nullptr; - if (headers.has_value()) { - for (const auto& [key, value] : headers.value()) { + if (headers) { + for (const auto& [key, value] : headers->m) { auto header = key + ": " + value; curl_headers = curl_slist_append(curl_headers, header.c_str()); } @@ -131,12 +142,14 @@ cpp::result SimpleGet(const std::string& url, curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); } - std::string readBuffer; + auto* response = new CurlResponse(); + std::shared_ptr s(response, + std::default_delete()); SetUpProxy(curl, url); curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlResponse::WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, response); if (timeout > 0) { curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeout); } @@ -155,10 +168,10 @@ cpp::result SimpleGet(const std::string& url, if (http_code >= 400) { CTL_ERR("HTTP request failed with status code: " + std::to_string(http_code)); - 
return cpp::fail(readBuffer); + return cpp::fail(response->GetData()); } - return readBuffer; + return response->GetData(); } cpp::result SimpleRequest( @@ -176,13 +189,15 @@ cpp::result SimpleRequest( curl_slist_append(curl_headers, "Content-Type: application/json"); curl_headers = curl_slist_append(curl_headers, "Expect:"); - if (headers.has_value()) { - for (const auto& [key, value] : headers.value()) { + if (headers) { + for (const auto& [key, value] : headers->m) { auto header = key + ": " + value; curl_headers = curl_slist_append(curl_headers, header.c_str()); } } - std::string readBuffer; + auto* response = new CurlResponse(); + std::shared_ptr s(response, + std::default_delete()); SetUpProxy(curl, url); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); @@ -196,8 +211,8 @@ cpp::result SimpleRequest( curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "DELETE"); } curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlResponse::WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, response); curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.length()); curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); @@ -221,10 +236,10 @@ cpp::result SimpleRequest( if (http_code >= 400) { CTL_ERR("HTTP request failed with status code: " + std::to_string(http_code)); - return cpp::fail(readBuffer); + return cpp::fail(response->GetData()); } - return readBuffer; + return response->GetData(); } cpp::result ReadRemoteYaml(const std::string& url) { diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index f33b7e8e5..9035b6b3c 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -15,8 +15,11 @@ enum class RequestType { GET, PATCH, POST, DEL }; namespace curl_utils { -std::optional> GetHeaders( - const std::string& url); +struct Header { + std::unordered_map m; 
+};
+
+std::shared_ptr<Header>
GetHeaders(const std::string& url); cpp::result SimpleGet(const std::string& url, const int timeout = -1); diff --git a/engine/utils/engine_matcher_utils.h b/engine/utils/engine_matcher_utils.h index a6135e532..28c0f0c2a 100644 --- a/engine/utils/engine_matcher_utils.h +++ b/engine/utils/engine_matcher_utils.h @@ -156,6 +156,11 @@ inline std::string Validate(const std::vector& variants, if (os == "mac" && !os_and_arch_compatible_list.empty()) return os_and_arch_compatible_list[0]; + if (os == "linux" && cpu_arch == "arm64" && + !os_and_arch_compatible_list.empty()) { + return os_and_arch_compatible_list[0]; + } + std::vector avx_compatible_list; std::copy_if(os_and_arch_compatible_list.begin(), From c508e6846b914030c809572f3aeef411dda002b0 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 13 Jan 2025 08:59:46 +0700 Subject: [PATCH 15/16] fix: add cpu_threads to model.yaml (#1845) Co-authored-by: vansangpfiev --- engine/cli/command_line_parser.cc | 1 + engine/cli/commands/model_upd_cmd.cc | 6 + engine/config/model_config.h | 8 ++ engine/config/yaml_config.cc | 124 ++++++++++---------- engine/test/components/test_format_utils.cc | 12 +- engine/test/components/test_yaml_handler.cc | 6 + engine/utils/format_utils.h | 18 +-- 7 files changed, 101 insertions(+), 74 deletions(-) diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 6f8f227e6..b423a6896 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -908,6 +908,7 @@ void CommandLineParser::ModelUpdate(CLI::App* parent) { "ngl", "ctx_len", "n_parallel", + "cpu_threads", "engine", "prompt_template", "system_template", diff --git a/engine/cli/commands/model_upd_cmd.cc b/engine/cli/commands/model_upd_cmd.cc index 6534d1fbd..1572581ec 100644 --- a/engine/cli/commands/model_upd_cmd.cc +++ b/engine/cli/commands/model_upd_cmd.cc @@ -228,6 +228,12 @@ void ModelUpdCmd::UpdateConfig(Json::Value& data, const std::string& key, data["n_parallel"] = 
static_cast(f); }); }}, + {"cpu_threads", + [this](Json::Value &data, const std::string& k, const std::string& v) { + UpdateNumericField(k, v, [&data](float f) { + data["cpu_threads"] = static_cast(f); + }); + }}, {"tp", [this](Json::Value &data, const std::string& k, const std::string& v) { UpdateNumericField(k, v, [&data](float f) { diff --git a/engine/config/model_config.h b/engine/config/model_config.h index d8ede92f7..ea671354e 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -164,6 +164,7 @@ struct ModelConfig { int ngl = std::numeric_limits::quiet_NaN(); int ctx_len = std::numeric_limits::quiet_NaN(); int n_parallel = 1; + int cpu_threads = -1; std::string engine; std::string prompt_template; std::string system_template; @@ -272,6 +273,8 @@ struct ModelConfig { ctx_len = json["ctx_len"].asInt(); if (json.isMember("n_parallel")) n_parallel = json["n_parallel"].asInt(); + if (json.isMember("cpu_threads")) + cpu_threads = json["cpu_threads"].asInt(); if (json.isMember("engine")) engine = json["engine"].asString(); if (json.isMember("prompt_template")) @@ -362,6 +365,9 @@ struct ModelConfig { obj["ngl"] = ngl; obj["ctx_len"] = ctx_len; obj["n_parallel"] = n_parallel; + if (cpu_threads > 0) { + obj["cpu_threads"] = cpu_threads; + } obj["engine"] = engine; obj["prompt_template"] = prompt_template; obj["system_template"] = system_template; @@ -474,6 +480,8 @@ struct ModelConfig { format_utils::MAGENTA); oss << format_utils::print_kv("n_parallel", std::to_string(n_parallel), format_utils::MAGENTA); + oss << format_utils::print_kv("cpu_threads", std::to_string(cpu_threads), + format_utils::MAGENTA); if (ngl != std::numeric_limits::quiet_NaN()) oss << format_utils::print_kv("ngl", std::to_string(ngl), format_utils::MAGENTA); diff --git a/engine/config/yaml_config.cc b/engine/config/yaml_config.cc index bbe7f430c..57b2b3ecb 100644 --- a/engine/config/yaml_config.cc +++ b/engine/config/yaml_config.cc @@ -119,6 +119,8 @@ void 
YamlHandler::ModelConfigFromYaml() { tmp.ctx_len = yaml_node_["ctx_len"].as(); if (yaml_node_["n_parallel"]) tmp.n_parallel = yaml_node_["n_parallel"].as(); + if (yaml_node_["cpu_threads"]) + tmp.cpu_threads = yaml_node_["cpu_threads"].as(); if (yaml_node_["tp"]) tmp.tp = yaml_node_["tp"].as(); if (yaml_node_["stream"]) @@ -224,6 +226,8 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) { yaml_node_["ctx_len"] = model_config_.ctx_len; if (!std::isnan(static_cast(model_config_.n_parallel))) yaml_node_["n_parallel"] = model_config_.n_parallel; + if (!std::isnan(static_cast(model_config_.cpu_threads))) + yaml_node_["cpu_threads"] = model_config_.cpu_threads; if (!std::isnan(static_cast(model_config_.tp))) yaml_node_["tp"] = model_config_.tp; if (!std::isnan(static_cast(model_config_.stream))) @@ -283,110 +287,112 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) { // Method to write all attributes to a YAML file void YamlHandler::WriteYamlFile(const std::string& file_path) const { try { - std::ofstream outFile(file_path); - if (!outFile) { + std::ofstream out_file(file_path); + if (!out_file) { throw std::runtime_error("Failed to open output file."); } // Write GENERAL GGUF METADATA - outFile << "# BEGIN GENERAL GGUF METADATA\n"; - outFile << format_utils::writeKeyValue( + out_file << "# BEGIN GENERAL GGUF METADATA\n"; + out_file << format_utils::WriteKeyValue( "id", yaml_node_["id"], "Model ID unique between models (author / quantization)"); - outFile << format_utils::writeKeyValue( + out_file << format_utils::WriteKeyValue( "model", yaml_node_["model"], "Model ID which is used for request construct - should be " "unique between models (author / quantization)"); - outFile << format_utils::writeKeyValue("name", yaml_node_["name"], + out_file << format_utils::WriteKeyValue("name", yaml_node_["name"], "metadata.general.name"); if (yaml_node_["version"]) { - outFile << "version: " << yaml_node_["version"].as() << "\n"; + out_file << 
"version: " << yaml_node_["version"].as() << "\n"; } if (yaml_node_["files"] && yaml_node_["files"].size()) { - outFile << "files: # Can be relative OR absolute local file " + out_file << "files: # Can be relative OR absolute local file " "path\n"; for (const auto& source : yaml_node_["files"]) { - outFile << " - " << source << "\n"; + out_file << " - " << source << "\n"; } } - outFile << "# END GENERAL GGUF METADATA\n"; - outFile << "\n"; + out_file << "# END GENERAL GGUF METADATA\n"; + out_file << "\n"; // Write INFERENCE PARAMETERS - outFile << "# BEGIN INFERENCE PARAMETERS\n"; - outFile << "# BEGIN REQUIRED\n"; + out_file << "# BEGIN INFERENCE PARAMETERS\n"; + out_file << "# BEGIN REQUIRED\n"; if (yaml_node_["stop"] && yaml_node_["stop"].size()) { - outFile << "stop: # tokenizer.ggml.eos_token_id\n"; + out_file << "stop: # tokenizer.ggml.eos_token_id\n"; for (const auto& stop : yaml_node_["stop"]) { - outFile << " - " << stop << "\n"; + out_file << " - " << stop << "\n"; } } - outFile << "# END REQUIRED\n"; - outFile << "\n"; - outFile << "# BEGIN OPTIONAL\n"; - outFile << format_utils::writeKeyValue("size", yaml_node_["size"]); - outFile << format_utils::writeKeyValue("stream", yaml_node_["stream"], + out_file << "# END REQUIRED\n"; + out_file << "\n"; + out_file << "# BEGIN OPTIONAL\n"; + out_file << format_utils::WriteKeyValue("size", yaml_node_["size"]); + out_file << format_utils::WriteKeyValue("stream", yaml_node_["stream"], "Default true?"); - outFile << format_utils::writeKeyValue("top_p", yaml_node_["top_p"], + out_file << format_utils::WriteKeyValue("top_p", yaml_node_["top_p"], "Ranges: 0 to 1"); - outFile << format_utils::writeKeyValue( + out_file << format_utils::WriteKeyValue( "temperature", yaml_node_["temperature"], "Ranges: 0 to 1"); - outFile << format_utils::writeKeyValue( + out_file << format_utils::WriteKeyValue( "frequency_penalty", yaml_node_["frequency_penalty"], "Ranges: 0 to 1"); - outFile << format_utils::writeKeyValue( + out_file << 
format_utils::WriteKeyValue( "presence_penalty", yaml_node_["presence_penalty"], "Ranges: 0 to 1"); - outFile << format_utils::writeKeyValue( + out_file << format_utils::WriteKeyValue( "max_tokens", yaml_node_["max_tokens"], "Should be default to context length"); - outFile << format_utils::writeKeyValue("seed", yaml_node_["seed"]); - outFile << format_utils::writeKeyValue("dynatemp_range", + out_file << format_utils::WriteKeyValue("seed", yaml_node_["seed"]); + out_file << format_utils::WriteKeyValue("dynatemp_range", yaml_node_["dynatemp_range"]); - outFile << format_utils::writeKeyValue("dynatemp_exponent", + out_file << format_utils::WriteKeyValue("dynatemp_exponent", yaml_node_["dynatemp_exponent"]); - outFile << format_utils::writeKeyValue("top_k", yaml_node_["top_k"]); - outFile << format_utils::writeKeyValue("min_p", yaml_node_["min_p"]); - outFile << format_utils::writeKeyValue("tfs_z", yaml_node_["tfs_z"]); - outFile << format_utils::writeKeyValue("typ_p", yaml_node_["typ_p"]); - outFile << format_utils::writeKeyValue("repeat_last_n", + out_file << format_utils::WriteKeyValue("top_k", yaml_node_["top_k"]); + out_file << format_utils::WriteKeyValue("min_p", yaml_node_["min_p"]); + out_file << format_utils::WriteKeyValue("tfs_z", yaml_node_["tfs_z"]); + out_file << format_utils::WriteKeyValue("typ_p", yaml_node_["typ_p"]); + out_file << format_utils::WriteKeyValue("repeat_last_n", yaml_node_["repeat_last_n"]); - outFile << format_utils::writeKeyValue("repeat_penalty", + out_file << format_utils::WriteKeyValue("repeat_penalty", yaml_node_["repeat_penalty"]); - outFile << format_utils::writeKeyValue("mirostat", yaml_node_["mirostat"]); - outFile << format_utils::writeKeyValue("mirostat_tau", + out_file << format_utils::WriteKeyValue("mirostat", yaml_node_["mirostat"]); + out_file << format_utils::WriteKeyValue("mirostat_tau", yaml_node_["mirostat_tau"]); - outFile << format_utils::writeKeyValue("mirostat_eta", + out_file << 
format_utils::WriteKeyValue("mirostat_eta", yaml_node_["mirostat_eta"]); - outFile << format_utils::writeKeyValue("penalize_nl", + out_file << format_utils::WriteKeyValue("penalize_nl", yaml_node_["penalize_nl"]); - outFile << format_utils::writeKeyValue("ignore_eos", + out_file << format_utils::WriteKeyValue("ignore_eos", yaml_node_["ignore_eos"]); - outFile << format_utils::writeKeyValue("n_probs", yaml_node_["n_probs"]); - outFile << format_utils::writeKeyValue("min_keep", yaml_node_["min_keep"]); - outFile << format_utils::writeKeyValue("grammar", yaml_node_["grammar"]); - outFile << "# END OPTIONAL\n"; - outFile << "# END INFERENCE PARAMETERS\n"; - outFile << "\n"; + out_file << format_utils::WriteKeyValue("n_probs", yaml_node_["n_probs"]); + out_file << format_utils::WriteKeyValue("min_keep", yaml_node_["min_keep"]); + out_file << format_utils::WriteKeyValue("grammar", yaml_node_["grammar"]); + out_file << "# END OPTIONAL\n"; + out_file << "# END INFERENCE PARAMETERS\n"; + out_file << "\n"; // Write MODEL LOAD PARAMETERS - outFile << "# BEGIN MODEL LOAD PARAMETERS\n"; - outFile << "# BEGIN REQUIRED\n"; - outFile << format_utils::writeKeyValue("engine", yaml_node_["engine"], + out_file << "# BEGIN MODEL LOAD PARAMETERS\n"; + out_file << "# BEGIN REQUIRED\n"; + out_file << format_utils::WriteKeyValue("engine", yaml_node_["engine"], "engine to run model"); - outFile << "prompt_template:"; - outFile << " " << yaml_node_["prompt_template"] << "\n"; - outFile << "# END REQUIRED\n"; - outFile << "\n"; - outFile << "# BEGIN OPTIONAL\n"; - outFile << format_utils::writeKeyValue( + out_file << "prompt_template:"; + out_file << " " << yaml_node_["prompt_template"] << "\n"; + out_file << "# END REQUIRED\n"; + out_file << "\n"; + out_file << "# BEGIN OPTIONAL\n"; + out_file << format_utils::WriteKeyValue( "ctx_len", yaml_node_["ctx_len"], "llama.context_length | 0 or undefined = loaded from model"); - outFile << format_utils::writeKeyValue("n_parallel", + out_file << 
format_utils::WriteKeyValue("n_parallel", yaml_node_["n_parallel"]); - outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"], + out_file << format_utils::WriteKeyValue("cpu_threads", + yaml_node_["cpu_threads"]); + out_file << format_utils::WriteKeyValue("ngl", yaml_node_["ngl"], "Undefined = loaded from model"); - outFile << "# END OPTIONAL\n"; - outFile << "# END MODEL LOAD PARAMETERS\n"; + out_file << "# END OPTIONAL\n"; + out_file << "# END MODEL LOAD PARAMETERS\n"; - outFile.close(); + out_file.close(); } catch (const std::exception& e) { std::cerr << "Error writing to file: " << e.what() << std::endl; throw; diff --git a/engine/test/components/test_format_utils.cc b/engine/test/components/test_format_utils.cc index cd777d5fa..d279b5940 100644 --- a/engine/test/components/test_format_utils.cc +++ b/engine/test/components/test_format_utils.cc @@ -9,37 +9,37 @@ TEST_F(FormatUtilsTest, WriteKeyValue) { { YAML::Node node; std::string result = - format_utils::writeKeyValue("key", node["does_not_exist"]); + format_utils::WriteKeyValue("key", node["does_not_exist"]); EXPECT_EQ(result, ""); } { YAML::Node node = YAML::Load("value"); - std::string result = format_utils::writeKeyValue("key", node); + std::string result = format_utils::WriteKeyValue("key", node); EXPECT_EQ(result, "key: value\n"); } { YAML::Node node = YAML::Load("3.14159"); - std::string result = format_utils::writeKeyValue("key", node); + std::string result = format_utils::WriteKeyValue("key", node); EXPECT_EQ(result, "key: 3.14159\n"); } { YAML::Node node = YAML::Load("3.000000"); - std::string result = format_utils::writeKeyValue("key", node); + std::string result = format_utils::WriteKeyValue("key", node); EXPECT_EQ(result, "key: 3\n"); } { YAML::Node node = YAML::Load("3.140000"); - std::string result = format_utils::writeKeyValue("key", node); + std::string result = format_utils::WriteKeyValue("key", node); EXPECT_EQ(result, "key: 3.14\n"); } { YAML::Node node = YAML::Load("value"); - 
std::string result = format_utils::writeKeyValue("key", node, "comment"); + std::string result = format_utils::WriteKeyValue("key", node, "comment"); EXPECT_EQ(result, "key: value # comment\n"); } } diff --git a/engine/test/components/test_yaml_handler.cc b/engine/test/components/test_yaml_handler.cc index f699e0c6a..c7e4b6a21 100644 --- a/engine/test/components/test_yaml_handler.cc +++ b/engine/test/components/test_yaml_handler.cc @@ -63,6 +63,7 @@ temperature: 0.7 max_tokens: 100 stream: true n_parallel: 2 +cpu_threads: 3 stop: - "END" files: @@ -84,6 +85,7 @@ n_parallel: 2 EXPECT_EQ(config.max_tokens, 100); EXPECT_TRUE(config.stream); EXPECT_EQ(config.n_parallel, 2); + EXPECT_EQ(config.cpu_threads, 3); EXPECT_EQ(config.stop.size(), 1); EXPECT_EQ(config.stop[0], "END"); EXPECT_EQ(config.files.size(), 1); @@ -104,6 +106,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) { new_config.max_tokens = 200; new_config.stream = false; new_config.n_parallel = 2; + new_config.cpu_threads = 3; new_config.stop = {"STOP", "END"}; new_config.files = {"updated_file1.gguf", "updated_file2.gguf"}; @@ -120,6 +123,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) { EXPECT_EQ(config.max_tokens, 200); EXPECT_FALSE(config.stream); EXPECT_EQ(config.n_parallel, 2); + EXPECT_EQ(config.cpu_threads, 3); EXPECT_EQ(config.stop.size(), 2); EXPECT_EQ(config.stop[0], "STOP"); EXPECT_EQ(config.stop[1], "END"); @@ -140,6 +144,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) { new_config.max_tokens = 150; new_config.stream = true; new_config.n_parallel = 2; + new_config.cpu_threads = 3; new_config.stop = {"HALT"}; new_config.files = {"write_test_file.gguf"}; @@ -164,6 +169,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) { EXPECT_EQ(read_config.max_tokens, 150); EXPECT_TRUE(read_config.stream); EXPECT_EQ(read_config.n_parallel, 2); + EXPECT_EQ(read_config.cpu_threads, 3); EXPECT_EQ(read_config.stop.size(), 1); EXPECT_EQ(read_config.stop[0], "HALT"); EXPECT_EQ(read_config.files.size(), 1); diff --git 
a/engine/utils/format_utils.h b/engine/utils/format_utils.h index 141866378..5dccee359 100644 --- a/engine/utils/format_utils.h +++ b/engine/utils/format_utils.h @@ -46,13 +46,13 @@ inline std::string print_float(const std::string& key, float value) { } else return ""; }; -inline std::string writeKeyValue(const std::string& key, +inline std::string WriteKeyValue(const std::string& key, const YAML::Node& value, const std::string& comment = "") { - std::ostringstream outFile; + std::ostringstream out_file; if (!value) return ""; - outFile << key << ": "; + out_file << key << ": "; // Check if the value is a float and round it to 6 decimal places if (value.IsScalar()) { @@ -66,19 +66,19 @@ inline std::string writeKeyValue(const std::string& key, if (strValue.back() == '.') { strValue.pop_back(); } - outFile << strValue; + out_file << strValue; } catch (const std::exception& e) { - outFile << value; // If not a float, write as is + out_file << value; // If not a float, write as is } } else { - outFile << value; + out_file << value; } if (!comment.empty()) { - outFile << " # " << comment; + out_file << " # " << comment; } - outFile << "\n"; - return outFile.str(); + out_file << "\n"; + return out_file.str(); }; inline std::string BytesToHumanReadable(uint64_t bytes) { From f8c1df69836d96d96e257dc75a0c9a2ee32611b6 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 13 Jan 2025 09:43:50 +0700 Subject: [PATCH 16/16] fix: engine Issues & API Issues (#1811) * chore: convention * fix: correct get remote model list * feat: auto generate remote model config * feat: support update remote engine * fix: do not generate remote model * chore: change engine_name to engine * fix: api key template on engine level * fix: add type for local engine * chore: cleanup * fix: add remote engine to /v1/engines GET * fix: build * fix: load engine when start model * chore: add log * fix: ignore chat_completions in model * fix: delete remote model * fix: replace api_key_template by 
header_template * fix: use engine from model yaml * fix: better error handling in stream mode * chore: cleanup * chore: unit test for anthropic response --------- Co-authored-by: vansangpfiev --- docs/static/openapi/cortex.json | 16 +- engine/common/engine_servicei.h | 8 +- engine/config/model_config.h | 58 +--- engine/config/remote_template.h | 66 ---- engine/controllers/engines.cc | 164 ++++++++-- engine/controllers/engines.h | 6 + engine/controllers/models.cc | 3 +- engine/controllers/server.cc | 15 +- engine/cortex-common/remote_enginei.h | 4 +- engine/database/engines.h | 2 +- engine/extensions/remote-engine/helper.h | 80 +++++ .../extensions/remote-engine/remote_engine.cc | 298 +++++++++--------- .../extensions/remote-engine/remote_engine.h | 15 +- engine/services/engine_service.cc | 39 ++- engine/services/engine_service.h | 4 +- engine/services/inference_service.cc | 5 + engine/services/inference_service.h | 2 + engine/services/model_service.cc | 31 +- engine/services/model_service.h | 2 + engine/test/components/test_remote_engine.cc | 167 +++++++++- 20 files changed, 663 insertions(+), 322 deletions(-) delete mode 100644 engine/config/remote_template.h create mode 100644 engine/extensions/remote-engine/helper.h diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index d006f0f2d..2deb15e5e 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -5388,8 +5388,8 @@ "engine", "version", "inference_params", - "TransformReq", - "TransformResp", + "transform_req", + "transform_resp", "metadata" ], "properties": { @@ -5397,9 +5397,9 @@ "type": "string", "description": "The identifier of the model." }, - "api_key_template": { + "header_template": { "type": "string", - "description": "Template for the API key header." + "description": "Template for the header." 
}, "engine": { "type": "string", @@ -5432,7 +5432,7 @@ } } }, - "TransformReq": { + "transform_req": { "type": "object", "properties": { "get_models": { @@ -5454,7 +5454,7 @@ } } }, - "TransformResp": { + "transform_resp": { "type": "object", "properties": { "chat_completions": { @@ -6162,9 +6162,9 @@ "description": "Number of GPU layers.", "example": 33 }, - "api_key_template": { + "header_template": { "type": "string", - "description": "Template for the API key header." + "description": "Template for the header." }, "version": { "type": "string", diff --git a/engine/common/engine_servicei.h b/engine/common/engine_servicei.h index a4b0c8732..ceb9b2fec 100644 --- a/engine/common/engine_servicei.h +++ b/engine/common/engine_servicei.h @@ -25,12 +25,14 @@ struct EngineVariantResponse { std::string name; std::string version; std::string engine; + std::string type; Json::Value ToJson() const { Json::Value root; root["name"] = name; root["version"] = version; root["engine"] = engine; + root["type"] = type.empty() ? 
"local" : type; return root; } }; @@ -57,7 +59,7 @@ class EngineServiceI { virtual cpp::result GetEngineByNameAndVariant( const std::string& engine_name, - const std::optional variant = std::nullopt) = 0; - - virtual bool IsRemoteEngine(const std::string& engine_name) = 0; + const std::optional variant = std::nullopt) const = 0; + + virtual bool IsRemoteEngine(const std::string& engine_name) const = 0; }; diff --git a/engine/config/model_config.h b/engine/config/model_config.h index ea671354e..1d51cfb01 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -11,7 +11,6 @@ #include #include #include -#include "config/remote_template.h" #include "utils/format_utils.h" #include "utils/remote_models_utils.h" @@ -19,15 +18,15 @@ namespace config { struct RemoteModelConfig { std::string model; - std::string api_key_template; + std::string header_template; std::string engine; std::string version; - std::size_t created; + size_t created; std::string object = "model"; std::string owned_by = ""; Json::Value inference_params; - Json::Value TransformReq; - Json::Value TransformResp; + Json::Value transform_req; + Json::Value transform_resp; Json::Value metadata; void LoadFromJson(const Json::Value& json) { if (!json.isObject()) { @@ -36,8 +35,8 @@ struct RemoteModelConfig { // Load basic string fields model = json.get("model", model).asString(); - api_key_template = - json.get("api_key_template", api_key_template).asString(); + header_template = + json.get("header_template", header_template).asString(); engine = json.get("engine", engine).asString(); version = json.get("version", version).asString(); created = @@ -47,31 +46,8 @@ struct RemoteModelConfig { // Load JSON object fields directly inference_params = json.get("inference_params", inference_params); - TransformReq = json.get("TransformReq", TransformReq); - // Use default template if it is empty, currently we only support 2 remote engines - auto is_anthropic = [](const std::string& model) { - 
return model.find("claude") != std::string::npos; - }; - if (TransformReq["chat_completions"]["template"].isNull()) { - if (is_anthropic(model)) { - TransformReq["chat_completions"]["template"] = - kAnthropicTransformReqTemplate; - } else { - TransformReq["chat_completions"]["template"] = - kOpenAITransformReqTemplate; - } - } - TransformResp = json.get("TransformResp", TransformResp); - if (TransformResp["chat_completions"]["template"].isNull()) { - if (is_anthropic(model)) { - TransformResp["chat_completions"]["template"] = - kAnthropicTransformRespTemplate; - } else { - TransformResp["chat_completions"]["template"] = - kOpenAITransformRespTemplate; - } - } - + transform_req = json.get("transform_req", transform_req); + transform_resp = json.get("transform_resp", transform_resp); metadata = json.get("metadata", metadata); } @@ -80,7 +56,7 @@ struct RemoteModelConfig { // Add basic string fields json["model"] = model; - json["api_key_template"] = api_key_template; + json["header_template"] = header_template; json["engine"] = engine; json["version"] = version; json["created"] = static_cast(created); @@ -89,8 +65,8 @@ struct RemoteModelConfig { // Add JSON object fields directly json["inference_params"] = inference_params; - json["TransformReq"] = TransformReq; - json["TransformResp"] = TransformResp; + json["transform_req"] = transform_req; + json["transform_resp"] = transform_resp; json["metadata"] = metadata; return json; @@ -101,7 +77,7 @@ struct RemoteModelConfig { // Convert basic fields root["model"] = model; - root["api_key_template"] = api_key_template; + root["header_template"] = header_template; root["engine"] = engine; root["version"] = version; root["object"] = object; @@ -111,8 +87,8 @@ struct RemoteModelConfig { // Convert Json::Value to YAML::Node using utility function root["inference_params"] = remote_models_utils::jsonToYaml(inference_params); - root["TransformReq"] = remote_models_utils::jsonToYaml(TransformReq); - root["TransformResp"] = 
remote_models_utils::jsonToYaml(TransformResp); + root["transform_req"] = remote_models_utils::jsonToYaml(transform_req); + root["transform_resp"] = remote_models_utils::jsonToYaml(transform_resp); root["metadata"] = remote_models_utils::jsonToYaml(metadata); // Save to file @@ -134,7 +110,7 @@ struct RemoteModelConfig { // Load basic fields model = root["model"].as(""); - api_key_template = root["api_key_template"].as(""); + header_template = root["header_template"].as(""); engine = root["engine"].as(""); version = root["version"] ? root["version"].as() : ""; created = root["created"] ? root["created"].as() : 0; @@ -144,8 +120,8 @@ struct RemoteModelConfig { // Load complex fields using utility function inference_params = remote_models_utils::yamlToJson(root["inference_params"]); - TransformReq = remote_models_utils::yamlToJson(root["TransformReq"]); - TransformResp = remote_models_utils::yamlToJson(root["TransformResp"]); + transform_req = remote_models_utils::yamlToJson(root["transform_req"]); + transform_resp = remote_models_utils::yamlToJson(root["transform_resp"]); metadata = remote_models_utils::yamlToJson(root["metadata"]); } }; diff --git a/engine/config/remote_template.h b/engine/config/remote_template.h deleted file mode 100644 index 8a17aaa9a..000000000 --- a/engine/config/remote_template.h +++ /dev/null @@ -1,66 +0,0 @@ -#include - -namespace config { -const std::string kOpenAITransformReqTemplate = - R"({ {% set first = true %} {% for key, value in input_request %} {% if key == "messages" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or 
key == "top_p" or key == "parallel_tool_calls" or key == "user" %} {% if not first %},{% endif %} "{{ key }}": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; -const std::string kOpenAITransformRespTemplate = - R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == "id" or key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "system_fingerprint" or key == "object" or key == "usage" -%} {%- if not first -%},{%- endif -%} "{{ key }}": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; -const std::string kAnthropicTransformReqTemplate = - R"({ - {% for key, value in input_request %} - {% if key == "messages" %} - {% if input_request.messages.0.role == "system" %} - "system": "{{ input_request.messages.0.content }}", - "messages": [ - {% for message in input_request.messages %} - {% if not loop.is_first %} - {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} - {% endif %} - {% endfor %} - ] - {% else %} - "messages": [ - {% for message in input_request.messages %} - {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} - {% endfor %} - ] - {% endif %} - {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} - "{{ key }}": {{ tojson(value) }} - {% endif %} - {% if not loop.is_last %},{% endif %} - {% endfor %} })"; -const std::string 
kAnthropicTransformRespTemplate = R"({ - "id": "{{ input_request.id }}", - "created": null, - "object": "chat.completion", - "model": "{{ input_request.model }}", - "choices": [ - { - "index": 0, - "message": { - "role": "{{ input_request.role }}", - "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", - "refusal": null - }, - "logprobs": null, - "finish_reason": "{{ input_request.stop_reason }}" - } - ], - "usage": { - "prompt_tokens": {{ input_request.usage.input_tokens }}, - "completion_tokens": {{ input_request.usage.output_tokens }}, - "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, - "prompt_tokens_details": { - "cached_tokens": 0 - }, - "completion_tokens_details": { - "reasoning_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0 - } - }, - "system_fingerprint": "fp_6b68a8204b" - })"; - -} // namespace config \ No newline at end of file diff --git a/engine/controllers/engines.cc b/engine/controllers/engines.cc index 24e61ba4f..8cf98785e 100644 --- a/engine/controllers/engines.cc +++ b/engine/controllers/engines.cc @@ -3,7 +3,9 @@ #include "utils/archive_utils.h" #include "utils/cortex_utils.h" #include "utils/engine_constants.h" +#include "utils/http_util.h" #include "utils/logging_utils.h" +#include "utils/scope_exit.h" #include "utils/string_utils.h" namespace { @@ -185,21 +187,58 @@ void Engines::InstallEngine( norm_version = version; } - if ((req->getJsonObject()) && - (*(req->getJsonObject())).get("type", "").asString() == "remote") { - auto type = (*(req->getJsonObject())).get("type", "").asString(); - auto api_key = (*(req->getJsonObject())).get("api_key", "").asString(); - auto url = (*(req->getJsonObject())).get("url", "").asString(); + auto result = + engine_service_->InstallEngineAsync(engine, norm_version, norm_variant); + if (result.has_error()) { + Json::Value res; + res["message"] = 
result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + CTL_INF("Error: " << result.error()); + callback(resp); + } else { + Json::Value res; + res["message"] = "Engine starts installing!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k200OK); + CTL_INF("Engine starts installing!"); + callback(resp); + } +} + +void Engines::InstallRemoteEngine( + const HttpRequestPtr& req, + std::function&& callback) { + if (!http_util::HasFieldInReq(req, callback, "engine")) { + return; + } + std::optional norm_variant = std::nullopt; + std::string norm_version{"latest"}; + + if (req->getJsonObject() != nullptr) { + auto variant = (*(req->getJsonObject())).get("variant", "").asString(); + auto version = + (*(req->getJsonObject())).get("version", "latest").asString(); + + if (!variant.empty()) { + norm_variant = variant; + } + norm_version = version; + } + + std::string engine; + if (auto o = req->getJsonObject(); o) { + engine = (*o).get("engine", "").asString(); + auto type = (*o).get("type", "").asString(); + auto api_key = (*o).get("api_key", "").asString(); + auto url = (*o).get("url", "").asString(); auto variant = norm_variant.value_or("all-platforms"); - auto status = (*(req->getJsonObject())).get("status", "Default").asString(); + auto status = (*o).get("status", "Default").asString(); std::string metadata; - if ((*(req->getJsonObject())).isMember("metadata") && - (*(req->getJsonObject()))["metadata"].isObject()) { - metadata = (*(req->getJsonObject())) - .get("metadata", Json::Value(Json::objectValue)) - .toStyledString(); - } else if ((*(req->getJsonObject())).isMember("metadata") && - !(*(req->getJsonObject()))["metadata"].isObject()) { + if ((*o).isMember("metadata") && (*o)["metadata"].isObject()) { + metadata = + (*o).get("metadata", Json::Value(Json::objectValue)).toStyledString(); + } else if ((*o).isMember("metadata") && !(*o)["metadata"].isObject()) { 
Json::Value res; res["message"] = "metadata must be object"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); @@ -208,8 +247,7 @@ void Engines::InstallEngine( return; } - auto get_models_url = (*(req->getJsonObject())) - .get("metadata", Json::Value(Json::objectValue)) + auto get_models_url = (*o).get("metadata", Json::Value(Json::objectValue)) .get("get_models_url", "") .asString(); @@ -262,25 +300,6 @@ void Engines::InstallEngine( resp->setStatusCode(k200OK); callback(resp); } - return; - } - - auto result = - engine_service_->InstallEngineAsync(engine, norm_version, norm_variant); - if (result.has_error()) { - Json::Value res; - res["message"] = result.error(); - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k400BadRequest); - CTL_INF("Error: " << result.error()); - callback(resp); - } else { - Json::Value res; - res["message"] = "Engine starts installing!"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k200OK); - CTL_INF("Engine starts installing!"); - callback(resp); } } @@ -288,6 +307,24 @@ void Engines::GetInstalledEngineVariants( const HttpRequestPtr& req, std::function&& callback, const std::string& engine) const { + + if (engine_service_->IsRemoteEngine(engine)) { + auto remote_engines = engine_service_->GetEngines(); + Json::Value releases(Json::arrayValue); + if (remote_engines.has_value()) { + for (auto e : remote_engines.value()) { + if (e.type == kRemote && e.engine_name == engine) { + releases.append(e.ToJson()); + break; + } + } + } + auto resp = cortex_utils::CreateCortexHttpJsonResponse(releases); + resp->setStatusCode(k200OK); + callback(resp); + return; + } + auto result = engine_service_->GetInstalledEngineVariants(engine); if (result.has_error()) { Json::Value res; @@ -310,6 +347,65 @@ void Engines::UpdateEngine( const HttpRequestPtr& req, std::function&& callback, const std::string& engine) { + + if (engine_service_->IsRemoteEngine(engine)) { + auto 
exist_engine = engine_service_->GetEngineByNameAndVariant(engine); + // only allow 1 variant 1 version of a remote engine name + if (!exist_engine) { + Json::Value res; + res["message"] = "Remote engine '" + engine + "' is not installed"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + if (auto o = req->getJsonObject(); o) { + auto type = (*o).get("type", (*exist_engine).type).asString(); + auto api_key = (*o).get("api_key", (*exist_engine).api_key).asString(); + auto url = (*o).get("url", (*exist_engine).url).asString(); + auto status = (*o).get("status", (*exist_engine).status).asString(); + auto version = (*o).get("version", "latest").asString(); + std::string metadata; + if ((*o).isMember("metadata") && (*o)["metadata"].isObject()) { + metadata = (*o).get("metadata", Json::Value(Json::objectValue)) + .toStyledString(); + } else if ((*o).isMember("metadata") && !(*o)["metadata"].isObject()) { + Json::Value res; + res["message"] = "metadata must be object"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } else { + metadata = (*exist_engine).metadata; + } + + auto upd_res = + engine_service_->UpsertEngine(engine, type, api_key, url, version, + "all-platforms", status, metadata); + if (upd_res.has_error()) { + Json::Value res; + res["message"] = upd_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + Json::Value res; + res["message"] = "Remote Engine update successfully!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k200OK); + callback(resp); + } + } else { + Json::Value res; + res["message"] = "Request body is empty!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + } + } + return; + } + auto 
result = engine_service_->UpdateEngine(engine); if (result.has_error()) { Json::Value res; diff --git a/engine/controllers/engines.h b/engine/controllers/engines.h index 3ad9708e3..78df6ccfb 100644 --- a/engine/controllers/engines.h +++ b/engine/controllers/engines.h @@ -16,6 +16,8 @@ class Engines : public drogon::HttpController { METHOD_ADD(Engines::InstallEngine, "/{1}/install", Options, Post); ADD_METHOD_TO(Engines::InstallEngine, "/v1/engines/{1}/install", Options, Post); + METHOD_ADD(Engines::InstallRemoteEngine, "/engines", Options, Post); + ADD_METHOD_TO(Engines::InstallRemoteEngine, "/v1/engines", Options, Post); // uninstall engine METHOD_ADD(Engines::UninstallEngine, "/{1}/install", Options, Delete); @@ -68,6 +70,10 @@ class Engines : public drogon::HttpController { std::function&& callback, const std::string& engine); + void InstallRemoteEngine( + const HttpRequestPtr& req, + std::function&& callback); + void UninstallEngine(const HttpRequestPtr& req, std::function&& callback, const std::string& engine); diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 34c6504ac..5bf02aa46 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -701,11 +701,10 @@ void Models::AddRemoteModel( // Use relative path for model_yaml_path. 
In case of import, we use absolute path for model auto yaml_rel_path = fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); - // TODO: remove hardcode "openai" when engine is finish cortex::db::ModelEntry model_entry{ model_handle, "", "", yaml_rel_path.string(), model_handle, "remote", "imported", cortex::db::ModelStatus::Remote, - "openai"}; + engine_name}; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); if (db_service_->AddModelEntry(model_entry).value()) { diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 83eaddb4e..a8cff2166 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -44,6 +44,11 @@ void server::ChatCompletion( } }(); + if (auto efm = inference_svc_->GetEngineByModelId(model_id); !efm.empty()) { + engine_type = efm; + (*json_body)["engine"] = efm; + } + LOG_DEBUG << "request body: " << json_body->toStyledString(); auto q = std::make_shared(); auto ir = inference_svc_->HandleChatCompletion(q, json_body); @@ -203,7 +208,6 @@ void server::RouteRequest( ProcessNonStreamRes(callback, *q); LOG_TRACE << "Done route request"; } - } void server::LoadModel(const HttpRequestPtr& req, @@ -223,7 +227,7 @@ void server::ProcessStreamRes(std::function cb, auto err_or_done = std::make_shared(false); auto chunked_content_provider = [this, q, err_or_done, engine_type, model_id]( char* buf, - std::size_t buf_size) -> std::size_t { + std::size_t buf_size) -> std::size_t { if (buf == nullptr) { LOG_TRACE << "Buf is null"; if (!(*err_or_done)) { @@ -243,7 +247,12 @@ void server::ProcessStreamRes(std::function cb, *err_or_done = true; } - auto str = res["data"].asString(); + std::string str; + if (status["status_code"].asInt() != k200OK) { + str = json_helper::DumpJsonString(res); + } else { + str = res["data"].asString(); + } LOG_DEBUG << "data: " << str; std::size_t n = std::min(str.size(), buf_size); memcpy(buf, str.data(), n); diff --git 
a/engine/cortex-common/remote_enginei.h b/engine/cortex-common/remote_enginei.h index 81ffbf5cd..835f526a0 100644 --- a/engine/cortex-common/remote_enginei.h +++ b/engine/cortex-common/remote_enginei.h @@ -33,5 +33,7 @@ class RemoteEngineI { std::function&& callback) = 0; // Get available remote models - virtual Json::Value GetRemoteModels() = 0; + virtual Json::Value GetRemoteModels(const std::string& url, + const std::string& api_key, + const std::string& header_template) = 0; }; diff --git a/engine/database/engines.h b/engine/database/engines.h index 7429d0fa2..1312a9c67 100644 --- a/engine/database/engines.h +++ b/engine/database/engines.h @@ -27,7 +27,7 @@ struct EngineEntry { // Convert basic fields root["id"] = id; - root["engine_name"] = engine_name; + root["engine"] = engine_name; root["type"] = type; root["api_key"] = api_key; root["url"] = url; diff --git a/engine/extensions/remote-engine/helper.h b/engine/extensions/remote-engine/helper.h new file mode 100644 index 000000000..5a99e5f33 --- /dev/null +++ b/engine/extensions/remote-engine/helper.h @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include +#include + +namespace remote_engine { +std::vector GetReplacements(const std::string& header_template) { + std::vector replacements; + std::regex placeholder_regex(R"(\{\{(.*?)\}\})"); + std::smatch match; + + std::string template_copy = header_template; + while (std::regex_search(template_copy, match, placeholder_regex)) { + std::string key = match[1].str(); + replacements.push_back(key); + template_copy = match.suffix().str(); + } + + return replacements; +} + +std::vector ReplaceHeaderPlaceholders( + const std::string& header_template, + const std::unordered_map& replacements) { + std::vector result; + size_t start = 0; + size_t end = header_template.find("}}"); + + while (end != std::string::npos) { + // Extract the part + std::string part = header_template.substr(start, end - start + 2); + + // Replace variables in this part + for (const auto& 
var : replacements) { + std::string placeholder = "{{" + var.first + "}}"; + size_t pos = part.find(placeholder); + if (pos != std::string::npos) { + part.replace(pos, placeholder.length(), var.second); + } + } + + // Trim whitespace + part.erase(0, part.find_first_not_of(" \t\n\r\f\v")); + part.erase(part.find_last_not_of(" \t\n\r\f\v") + 1); + + // Add to result if not empty + if (!part.empty()) { + result.push_back(part); + } + + // Move to next part + start = end + 2; + end = header_template.find("}}", start); + } + + // Process any remaining part + if (start < header_template.length()) { + std::string part = header_template.substr(start); + + // Replace variables in this part + for (const auto& var : replacements) { + std::string placeholder = "{{" + var.first + "}}"; + size_t pos = part.find(placeholder); + if (pos != std::string::npos) { + part.replace(pos, placeholder.length(), var.second); + } + } + + // Trim whitespace + part.erase(0, part.find_first_not_of(" \t\n\r\f\v")); + part.erase(part.find_last_not_of(" \t\n\r\f\v") + 1); + + if (!part.empty()) { + result.push_back(part); + } + } + return result; +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 6361077dd..0d7ecbef1 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -1,8 +1,10 @@ #include "remote_engine.h" #include #include +#include #include #include +#include "helper.h" #include "utils/json_helper.h" #include "utils/logging_utils.h" namespace remote_engine { @@ -12,13 +14,6 @@ constexpr const int k400BadRequest = 400; constexpr const int k409Conflict = 409; constexpr const int k500InternalServerError = 500; constexpr const int kFileLoggerOption = 0; -bool is_anthropic(const std::string& model) { - return model.find("claude") != std::string::npos; -} - -bool is_openai(const std::string& model) { - 
return model.find("gpt") != std::string::npos; -} constexpr const std::array kAnthropicModels = { "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", @@ -31,9 +26,20 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { auto* context = static_cast(userdata); std::string chunk(ptr, size * nmemb); + CTL_DBG(chunk); + auto check_error = json_helper::ParseJsonString(chunk); + if (check_error.isMember("error")) { + CTL_WRN(chunk); + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = true; + status["status_code"] = k400BadRequest; + (*context->callback)(std::move(status), std::move(check_error)); + return size * nmemb; + } context->buffer += chunk; - // Process complete lines size_t pos; while ((pos = context->buffer.find('\n')) != std::string::npos) { @@ -59,24 +65,23 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, // Parse the JSON Json::Value chunk_json; - if (!is_openai(context->model)) { - std::string s = line.substr(6); - try { - auto root = json_helper::ParseJsonString(s); - root["model"] = context->model; - root["id"] = context->id; - root["stream"] = true; - auto result = context->renderer.Render(context->stream_template, root); - CTL_DBG(result); - chunk_json["data"] = "data: " + result + "\n\n"; - } catch (const std::exception& e) { - CTL_WRN("JSON parse error: " << e.what()); + std::string s = line; + if (line.size() > 6) + s = line.substr(6); + try { + auto root = json_helper::ParseJsonString(s); + if (root.getMemberNames().empty()) continue; - } - } else { - chunk_json["data"] = line + "\n\n"; + root["model"] = context->model; + root["id"] = context->id; + root["stream"] = true; + auto result = context->renderer.Render(context->stream_template, root); + CTL_DBG(result); + chunk_json["data"] = "data: " + result + "\n\n"; + } catch (const std::exception& e) { + CTL_WRN("JSON parse error: " << e.what()); + continue; } - Json::Reader reader; Json::Value 
status; status["is_done"] = false; @@ -102,17 +107,17 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( return response; } - std::string full_url = - config.transform_req["chat_completions"]["url"].as(); + std::string full_url = chat_url_; - struct curl_slist* headers = nullptr; - if (!config.api_key.empty()) { - headers = curl_slist_append(headers, api_key_template_.c_str()); + if (config.transform_req["chat_completions"]["url"]) { + full_url = + config.transform_req["chat_completions"]["url"].as(); } + CTL_DBG("full_url: " << full_url); - if (is_anthropic(config.model)) { - std::string v = "anthropic-version: " + config.version; - headers = curl_slist_append(headers, v.c_str()); + struct curl_slist* headers = nullptr; + for (auto const& h : header_) { + headers = curl_slist_append(headers, h.c_str()); } headers = curl_slist_append(headers, "Content-Type: application/json"); @@ -121,6 +126,12 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( headers = curl_slist_append(headers, "Connection: keep-alive"); std::string stream_template = chat_res_template_; + if (config.transform_resp["chat_completions"] && + config.transform_resp["chat_completions"]["template"]) { + // Model level overrides engine level + stream_template = + config.transform_resp["chat_completions"]["template"].as(); + } StreamContext context{ std::make_shared>( @@ -174,6 +185,21 @@ std::string ReplaceApiKeyPlaceholder(const std::string& templateStr, return result; } +std::vector ReplaceHeaderPlaceholders( + const std::string& template_str, Json::Value json_body) { + CTL_DBG(template_str); + auto keys = GetReplacements(template_str); + if (keys.empty()) + return std::vector{}; + std::unordered_map replacements; + for (auto const& k : keys) { + if (json_body.isMember(k)) { + replacements.insert({k, json_body[k].asString()}); + } + } + return ReplaceHeaderPlaceholders(template_str, replacements); +} + static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, 
std::string* data) { data->append(ptr, size * nmemb); @@ -181,7 +207,7 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, } RemoteEngine::RemoteEngine(const std::string& engine_name) - : engine_name_(engine_name) { + : engine_name_(engine_name), q_(1 /*n_parallel*/, engine_name) { curl_global_init(CURL_GLOBAL_ALL); } @@ -199,7 +225,9 @@ RemoteEngine::ModelConfig* RemoteEngine::GetModelConfig( return nullptr; } -CurlResponse RemoteEngine::MakeGetModelsRequest() { +CurlResponse RemoteEngine::MakeGetModelsRequest( + const std::string& url, const std::string& api_key, + const std::string& header_template) { CURL* curl = curl_easy_init(); CurlResponse response; @@ -209,13 +237,14 @@ CurlResponse RemoteEngine::MakeGetModelsRequest() { return response; } - std::string full_url = metadata_["get_models_url"].asString(); + std::string api_key_header = + ReplaceApiKeyPlaceholder(header_template, api_key); struct curl_slist* headers = nullptr; - headers = curl_slist_append(headers, api_key_template_.c_str()); + headers = curl_slist_append(headers, api_key_header.c_str()); headers = curl_slist_append(headers, "Content-Type: application/json"); - curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); std::string response_string; @@ -246,18 +275,19 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( response.error_message = "Failed to initialize CURL"; return response; } - std::string full_url = - config.transform_req["chat_completions"]["url"].as(); + std::string full_url = chat_url_; - struct curl_slist* headers = nullptr; - if (!config.api_key.empty()) { - headers = curl_slist_append(headers, api_key_template_.c_str()); + if (config.transform_req["chat_completions"]["url"]) { + full_url = + config.transform_req["chat_completions"]["url"].as(); } + CTL_DBG("full_url: " << full_url); - if (is_anthropic(config.model)) { - std::string v = 
"anthropic-version: " + config.version; - headers = curl_slist_append(headers, v.c_str()); + struct curl_slist* headers = nullptr; + for (auto const& h : header_) { + headers = curl_slist_append(headers, h.c_str()); } + headers = curl_slist_append(headers, "Content-Type: application/json"); curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); @@ -286,42 +316,30 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( bool RemoteEngine::LoadModelConfig(const std::string& model, const std::string& yaml_path, - const std::string& api_key) { + const Json::Value& body) { try { YAML::Node config = YAML::LoadFile(yaml_path); ModelConfig model_config; model_config.model = model; - if (is_anthropic(model)) { - if (!config["version"]) { - CTL_ERR("Missing version for model: " << model); - return false; - } - model_config.version = config["version"].as(); - } - - // Required fields - if (!config["api_key_template"]) { - LOG_ERROR << "Missing required fields in config for model " << model; - return false; - } - - model_config.api_key = api_key; + model_config.api_key = body["api_key"].asString(); // model_config.url = ; // Optional fields - if (config["api_key_template"]) { - api_key_template_ = ReplaceApiKeyPlaceholder( - config["api_key_template"].as(), api_key); + if (auto s = config["header_template"]; s && !s.as().empty()) { + header_ = ReplaceHeaderPlaceholders(s.as(), body); + for (auto const& h : header_) { + CTL_DBG("header: " << h); + } } - if (config["TransformReq"]) { - model_config.transform_req = config["TransformReq"]; + if (config["transform_req"]) { + model_config.transform_req = config["transform_req"]; } else { - LOG_WARN << "Missing TransformReq in config for model " << model; + LOG_WARN << "Missing transform_req in config for model " << model; } - if (config["TransformResp"]) { - model_config.transform_resp = config["TransformResp"]; + if (config["transform_resp"]) { + model_config.transform_resp = config["transform_resp"]; } else { - LOG_WARN << 
"Missing TransformResp in config for model " << model; + LOG_WARN << "Missing transform_resp in config for model " << model; } model_config.is_loaded = true; @@ -393,34 +411,54 @@ void RemoteEngine::LoadModel( const std::string& model_path = (*json_body)["model_path"].asString(); const std::string& api_key = (*json_body)["api_key"].asString(); - if (!LoadModelConfig(model, model_path, api_key)) { - Json::Value error; - error["error"] = "Failed to load model configuration"; - Json::Value status; - status["is_done"] = true; - status["has_error"] = true; - status["is_stream"] = false; - status["status_code"] = k500InternalServerError; - callback(std::move(status), std::move(error)); - return; - } if (json_body->isMember("metadata")) { metadata_ = (*json_body)["metadata"]; - if (!metadata_["TransformReq"].isNull() && - !metadata_["TransformReq"]["chat_completions"].isNull() && - !metadata_["TransformReq"]["chat_completions"]["template"].isNull()) { + if (!metadata_["transform_req"].isNull() && + !metadata_["transform_req"]["chat_completions"].isNull() && + !metadata_["transform_req"]["chat_completions"]["template"].isNull()) { chat_req_template_ = - metadata_["TransformReq"]["chat_completions"]["template"].asString(); + metadata_["transform_req"]["chat_completions"]["template"].asString(); CTL_INF(chat_req_template_); } - if (!metadata_["TransformResp"].isNull() && - !metadata_["TransformResp"]["chat_completions"].isNull() && - !metadata_["TransformResp"]["chat_completions"]["template"].isNull()) { + if (!metadata_["transform_resp"].isNull() && + !metadata_["transform_resp"]["chat_completions"].isNull() && + !metadata_["transform_resp"]["chat_completions"]["template"].isNull()) { chat_res_template_ = - metadata_["TransformResp"]["chat_completions"]["template"].asString(); + metadata_["transform_resp"]["chat_completions"]["template"] + .asString(); CTL_INF(chat_res_template_); } + + if (!metadata_["transform_req"].isNull() && + 
!metadata_["transform_req"]["chat_completions"].isNull() && + !metadata_["transform_req"]["chat_completions"]["url"].isNull()) { + chat_url_ = + metadata_["transform_req"]["chat_completions"]["url"].asString(); + CTL_INF(chat_url_); + } + } + + if (json_body->isMember("metadata")) { + if (!metadata_["header_template"].isNull()) { + header_ = ReplaceHeaderPlaceholders( + metadata_["header_template"].asString(), *json_body); + for (auto const& h : header_) { + CTL_DBG("header: " << h); + } + } + } + + if (!LoadModelConfig(model, model_path, *json_body)) { + Json::Value error; + error["error"] = "Failed to load model configuration"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + callback(std::move(status), std::move(error)); + return; } Json::Value response; @@ -501,15 +539,6 @@ void RemoteEngine::HandleChatCompletion( // Transform request std::string result; try { - // Check if required YAML nodes exist - if (!model_config->transform_req["chat_completions"]) { - throw std::runtime_error( - "Missing 'chat_completions' node in transform_req"); - } - if (!model_config->transform_req["chat_completions"]["template"]) { - throw std::runtime_error("Missing 'template' node in chat_completions"); - } - // Validate JSON body if (!json_body || json_body->isNull()) { throw std::runtime_error("Invalid or null JSON body"); @@ -517,12 +546,16 @@ void RemoteEngine::HandleChatCompletion( // Get template string with error check std::string template_str; - try { + if (!chat_req_template_.empty()) { + CTL_DBG("Use engine transform request template: " << chat_req_template_); + template_str = chat_req_template_; + } + if (model_config->transform_req["chat_completions"] && + model_config->transform_req["chat_completions"]["template"]) { + // Model level overrides engine level template_str = model_config->transform_req["chat_completions"]["template"] .as(); - } catch (const 
YAML::BadConversion& e) { - throw std::runtime_error("Failed to convert template node to string: " + - std::string(e.what())); + CTL_DBG("Use model transform request template: " << template_str); } // Render with error handling @@ -540,7 +573,9 @@ void RemoteEngine::HandleChatCompletion( } if (is_stream) { - MakeStreamingChatCompletionRequest(*model_config, result, callback); + q_.runTaskInQueue([this, model_config, result, cb = std::move(callback)] { + MakeStreamingChatCompletionRequest(*model_config, result, cb); + }); } else { auto response = MakeChatCompletionRequest(*model_config, result); @@ -579,33 +614,14 @@ void RemoteEngine::HandleChatCompletion( CTL_DBG( "Use engine transform response template: " << chat_res_template_); template_str = chat_res_template_; - } else { - // Check if required YAML nodes exist - if (!model_config->transform_resp["chat_completions"]) { - throw std::runtime_error( - "Missing 'chat_completions' node in transform_resp"); - } - if (!model_config->transform_resp["chat_completions"]["template"]) { - throw std::runtime_error( - "Missing 'template' node in chat_completions"); - } - - // Validate JSON body - if (!response_json || response_json.isNull()) { - throw std::runtime_error("Invalid or null JSON body"); - } - - // Get template string with error check - - try { - template_str = - model_config->transform_resp["chat_completions"]["template"] - .as(); - } catch (const YAML::BadConversion& e) { - throw std::runtime_error( - "Failed to convert template node to string: " + - std::string(e.what())); - } + } + if (model_config->transform_resp["chat_completions"] && + model_config->transform_resp["chat_completions"]["template"]) { + // Model level overrides engine level + template_str = + model_config->transform_resp["chat_completions"]["template"] + .as(); + CTL_DBG("Use model transform request template: " << template_str); } try { @@ -686,9 +702,10 @@ void RemoteEngine::HandleEmbedding( callback(Json::Value(), Json::Value()); } 
-Json::Value RemoteEngine::GetRemoteModels() { - if (metadata_["get_models_url"].isNull() || - metadata_["get_models_url"].asString().empty()) { +Json::Value RemoteEngine::GetRemoteModels(const std::string& url, + const std::string& api_key, + const std::string& header_template) { + if (url.empty()) { if (engine_name_ == kAnthropicEngine) { Json::Value json_resp; Json::Value model_array(Json::arrayValue); @@ -709,20 +726,19 @@ Json::Value RemoteEngine::GetRemoteModels() { return Json::Value(); } } else { - auto response = MakeGetModelsRequest(); + auto response = MakeGetModelsRequest(url, api_key, header_template); if (response.error) { Json::Value error; error["error"] = response.error_message; + CTL_WRN(response.error_message); return error; } - Json::Value response_json; - Json::Reader reader; - if (!reader.parse(response.body, response_json)) { - Json::Value error; - error["error"] = "Failed to parse response"; - return error; + CTL_DBG(response.body); + auto body_json = json_helper::ParseJsonString(response.body); + if (body_json.isMember("error")) { + return body_json["error"]; } - return response_json; + return body_json; } } diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 6f08b5403..bc6d534c5 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -9,6 +9,7 @@ #include #include "cortex-common/remote_enginei.h" #include "extensions/template_renderer.h" +#include "trantor/utils/ConcurrentTaskQueue.h" #include "utils/engine_constants.h" #include "utils/file_logger.h" // Helper for CURL response @@ -50,8 +51,10 @@ class RemoteEngine : public RemoteEngineI { Json::Value metadata_; std::string chat_req_template_; std::string chat_res_template_; - std::string api_key_template_; + std::vector header_; std::string engine_name_; + std::string chat_url_; + trantor::ConcurrentTaskQueue q_; // Helper functions CurlResponse 
MakeChatCompletionRequest(const ModelConfig& config, @@ -60,11 +63,13 @@ class RemoteEngine : public RemoteEngineI { CurlResponse MakeStreamingChatCompletionRequest( const ModelConfig& config, const std::string& body, const std::function& callback); - CurlResponse MakeGetModelsRequest(); + CurlResponse MakeGetModelsRequest(const std::string& url, + const std::string& api_key, + const std::string& header_template); // Internal model management bool LoadModelConfig(const std::string& model, const std::string& yaml_path, - const std::string& api_key); + const Json::Value& body); ModelConfig* GetModelConfig(const std::string& model); public: @@ -97,7 +102,9 @@ class RemoteEngine : public RemoteEngineI { std::shared_ptr json_body, std::function&& callback) override; - Json::Value GetRemoteModels() override; + Json::Value GetRemoteModels(const std::string& url, + const std::string& api_key, + const std::string& header_template) override; }; } // namespace remote_engine \ No newline at end of file diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index c6b107af3..a80c5fe60 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -6,10 +6,10 @@ #include #include "algorithm" +#include "config/model_config.h" #include "database/engines.h" - +#include "database/models.h" #include "extensions/python-engine/python_engine.h" - #include "extensions/remote-engine/remote_engine.h" #include "utils/archive_utils.h" @@ -736,9 +736,11 @@ cpp::result EngineService::LoadEngine( return cpp::fail("Remote engine '" + engine_name + "' is not installed"); } - engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name); - - CTL_INF("Loaded engine: " << engine_name); + if (!IsEngineLoaded(engine_name)) { + engines_[engine_name].engine = + new remote_engine::RemoteEngine(engine_name); + CTL_INF("Loaded engine: " << engine_name); + } return {}; } @@ -1074,7 +1076,8 @@ cpp::result EngineService::GetEngineById( 
cpp::result EngineService::GetEngineByNameAndVariant( - const std::string& engine_name, const std::optional variant) { + const std::string& engine_name, + const std::optional variant) const { assert(db_service_); auto get_res = db_service_->GetEngineByNameAndVariant(engine_name, variant); @@ -1123,17 +1126,29 @@ cpp::result EngineService::GetRemoteModels( return cpp::fail(r.error()); } + auto exist_engine = GetEngineByNameAndVariant(engine_name); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine_name + "' is not installed"); + } + if (!IsEngineLoaded(engine_name)) { - auto exist_engine = GetEngineByNameAndVariant(engine_name); - if (exist_engine.has_error()) { - return cpp::fail("Remote engine '" + engine_name + "' is not installed"); - } engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name); CTL_INF("Loaded engine: " << engine_name); } + auto remote_engine_json = exist_engine.value().ToJson(); auto& e = std::get(engines_[engine_name].engine); - auto res = e->GetRemoteModels(); + auto url = remote_engine_json["metadata"]["get_models_url"].asString(); + auto api_key = remote_engine_json["api_key"].asString(); + auto header_template = + remote_engine_json["metadata"]["header_template"].asString(); + if (url.empty()) + CTL_WRN("url is empty"); + if (api_key.empty()) + CTL_WRN("api_key is empty"); + if (header_template.empty()) + CTL_WRN("header_template is empty"); + auto res = e->GetRemoteModels(url, api_key, header_template); if (!res["error"].isNull()) { return cpp::fail(res["error"].asString()); } else { @@ -1141,7 +1156,7 @@ cpp::result EngineService::GetRemoteModels( } } -bool EngineService::IsRemoteEngine(const std::string& engine_name) { +bool EngineService::IsRemoteEngine(const std::string& engine_name) const { auto ne = Repo2Engine(engine_name); auto local_engines = file_manager_utils::GetCortexConfig().supportedEngines; for (auto const& le : local_engines) { diff --git a/engine/services/engine_service.h 
b/engine/services/engine_service.h index a460582c6..f98037bab 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -139,7 +139,7 @@ class EngineService : public EngineServiceI { cpp::result GetEngineByNameAndVariant( const std::string& engine_name, - const std::optional variant = std::nullopt) override; + const std::optional variant = std::nullopt) const override; cpp::result UpsertEngine( const std::string& engine_name, const std::string& type, @@ -155,7 +155,7 @@ class EngineService : public EngineServiceI { void RegisterEngineLibPath(); - bool IsRemoteEngine(const std::string& engine_name) override; + bool IsRemoteEngine(const std::string& engine_name) const override; private: bool IsEngineLoaded(const std::string& engine); diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 3668fb6fe..057b6f716 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -394,3 +394,8 @@ bool InferenceService::HasFieldInReq(std::shared_ptr json_body, } return true; } + +std::string InferenceService::GetEngineByModelId( + const std::string& model_id) const { + return model_service_.lock()->GetEngineByModelId(model_id); +} diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index f23be3f23..794110f99 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -69,6 +69,8 @@ class InferenceService { model_service_ = model_service; } + std::string GetEngineByModelId(const std::string& model_id) const; + private: std::shared_ptr engine_service_; std::weak_ptr model_service_; diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 74767a9b2..f79c20859 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -768,7 +768,8 @@ cpp::result ModelService::DeleteModel( // Remove yaml file std::filesystem::remove(yaml_fp); // Remove model files if they 
are not imported locally - if (model_entry.value().branch_name != "imported") { + if (model_entry.value().branch_name != "imported" && + !engine_svc_->IsRemoteEngine(mc.engine)) { if (mc.files.size() > 0) { if (mc.engine == kLlamaRepo || mc.engine == kLlamaEngine) { for (auto& file : mc.files) { @@ -890,7 +891,7 @@ cpp::result ModelService::StartModel( // Running remote model if (engine_svc_->IsRemoteEngine(mc.engine)) { - + engine_svc_->LoadEngine(mc.engine); config::RemoteModelConfig remote_mc; remote_mc.LoadFromYamlFile( fmu::ToAbsoluteCortexDataPath( @@ -906,6 +907,10 @@ cpp::result ModelService::StartModel( json_data = remote_mc.ToJson(); json_data["api_key"] = std::move(remote_engine_json["api_key"]); + if (auto v = remote_engine_json["version"].asString(); + !v.empty() && v != "latest") { + json_data["version"] = v; + } json_data["model_path"] = fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.value().path_to_model_yaml)) @@ -1374,5 +1379,27 @@ ModelService::GetModelMetadata(const std::string& model_id) const { std::shared_ptr ModelService::GetCachedModelMetadata( const std::string& model_id) const { + if (loaded_model_metadata_map_.find(model_id) == + loaded_model_metadata_map_.end()) + return nullptr; return loaded_model_metadata_map_.at(model_id); } + +std::string ModelService::GetEngineByModelId( + const std::string& model_id) const { + namespace fs = std::filesystem; + namespace fmu = file_manager_utils; + auto model_entry = db_service_->GetModelInfo(model_id); + if (model_entry.has_error()) { + CTL_WRN("Error: " + model_entry.error()); + return ""; + } + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + auto mc = yaml_handler.GetModelConfig(); + CTL_DBG(mc.engine); + return mc.engine; +} \ No newline at end of file diff --git a/engine/services/model_service.h b/engine/services/model_service.h index cc659fea5..a668b27ba 100644 --- 
a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -96,6 +96,8 @@ class ModelService { std::shared_ptr GetCachedModelMetadata( const std::string& model_id) const; + std::string GetEngineByModelId(const std::string& model_id) const; + private: /** * Handle downloading model which have following pattern: author/model_name diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc index 5f1b85044..3bb3cdca3 100644 --- a/engine/test/components/test_remote_engine.cc +++ b/engine/test/components/test_remote_engine.cc @@ -1,3 +1,7 @@ +#include +#include +#include +#include "extensions/remote-engine/helper.h" #include "extensions/template_renderer.h" #include "gtest/gtest.h" #include "utils/json_helper.h" @@ -25,19 +29,22 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) { {% endfor %} ] {% endif %} + {% if not loop.is_last %},{% endif %} {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} "{{ key }}": {{ tojson(value) }} + {% if not loop.is_last %},{% endif %} {% endif %} - {% if not loop.is_last %},{% endif %} {% endfor %} })"; { std::string message_with_system = R"({ + "engine" : "anthropic", + "max_tokens" : 1024, "messages": [ {"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."}, {"role": "user", "content": "Hello, world"} ], "model": "claude-3-5-sonnet-20241022", - "max_tokens": 1024, + "stream" : true })"; auto data = 
json_helper::ParseJsonString(message_with_system); @@ -78,4 +85,160 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) { EXPECT_EQ(data["messages"][0]["content"].asString(), res_json["messages"][0]["content"].asString()); } +} + +TEST_F(RemoteEngineTest, OpenAiResponse) { + std::string tpl = R"({ + {% set first = true %} + {% for key, value in input_request %} + {% if key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "system_fingerprint" or key == "stream" or key == "object" or key == "usage" %} + {% if not first %},{% endif %} + "{{ key }}": {{ tojson(value) }} + {% set first = false %} + {% endif %} + {% endfor %} + })"; + std::string message = R"( + { + "choices": [ + { + "delta": { + "content": " questions" + }, + "finish_reason": null, + "index": 0 + } + ], + "created": 1735372587, + "id": "", + "model": "o1-preview", + "object": "chat.completion.chunk", + "stream": true, + "system_fingerprint": "fp_1ddf0263de" + })"; + auto data = json_helper::ParseJsonString(message); + + extensions::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["created"].asInt(), res_json["created"].asInt()); + EXPECT_EQ(data["choices"][0]["delta"]["content"].asString(), + res_json["choices"][0]["delta"]["content"].asString()); +} + +TEST_F(RemoteEngineTest, AnthropicResponse) { + std::string tpl = R"( + {% if input_request.stream %} + {"object": "chat.completion.chunk", "model": "{{ input_request.model }}", "choices": [{"index": 0, "delta": { {% if input_request.type == "message_start" %} "role": "assistant", "content": null {% else if input_request.type == "ping" %} "role": "assistant", "content": null {% else if input_request.type == "content_block_delta" %} "role": "assistant", "content": "{{ input_request.delta.text }}" {% else if input_request.type == "content_block_stop" %} "role": 
"assistant", "content": null {% else if input_request.type == "content_block_stop" %} "role": "assistant", "content": null {% endif %} }, {% if input_request.type == "content_block_stop" %} "finish_reason": "stop" {% else %} "finish_reason": null {% endif %} }]} + {% else %} + {"id": "{{ input_request.id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [{ + "index": 0, + "message": { + "role": "{{ input_request.role }}", + "content": {% if not input_request.content %} null {% else if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% else %} null {% endif %}, + "refusal": null }, + "logprobs": null, + "finish_reason": "{{ input_request.stop_reason }}" } ], + "usage": { + "prompt_tokens": {{ input_request.usage.input_tokens }}, + "completion_tokens": {{ input_request.usage.output_tokens }}, + "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, + "prompt_tokens_details": { "cached_tokens": 0 }, + "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, + "system_fingerprint": "fp_6b68a8204b"} + {% endif %})"; + std::string message = R"( + { + "content": [], + "id": "msg_01SckpnDyChcmmawQsWHr8CH", + "model": "claude-3-opus-20240229", + "role": "assistant", + "stop_reason": "end_turn", + "stop_sequence": null, + "stream": false, + "type": "message", + "usage": { + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "input_tokens": 130, + "output_tokens": 3 + } + })"; + auto data = json_helper::ParseJsonString(message); + + extensions::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["created"].asInt(), res_json["created"].asInt()); + 
EXPECT_TRUE(res_json["choices"][0]["message"]["content"].isNull()); +} + +TEST_F(RemoteEngineTest, HeaderTemplate) { + { + std::string header_template = + R"(x-api-key: {{api_key}} anthropic-version: {{version}})"; + Json::Value test_value; + test_value["api_key"] = "test"; + test_value["version"] = "test_version"; + std::unordered_map replacements; + auto r = remote_engine::GetReplacements(header_template); + for (auto s : r) { + if (test_value.isMember(s)) { + replacements.insert({s, test_value[s].asString()}); + } + } + + auto result = + remote_engine::ReplaceHeaderPlaceholders(header_template, replacements); + + EXPECT_EQ(result[0], "x-api-key: test"); + EXPECT_EQ(result[1], "anthropic-version: test_version"); + } + + { + std::string header_template = + R"(x-api-key: {{api_key}} anthropic-version: test_version)"; + Json::Value test_value; + test_value["api_key"] = "test"; + test_value["version"] = "test_version"; + std::unordered_map replacements; + auto r = remote_engine::GetReplacements(header_template); + for (auto s : r) { + if (test_value.isMember(s)) { + replacements.insert({s, test_value[s].asString()}); + } + } + + auto result = + remote_engine::ReplaceHeaderPlaceholders(header_template, replacements); + + EXPECT_EQ(result[0], "x-api-key: test"); + EXPECT_EQ(result[1], "anthropic-version: test_version"); + } + + { + std::string header_template = R"(Authorization: Bearer {{api_key}}")"; + Json::Value test_value; + test_value["api_key"] = "test"; + std::unordered_map replacements; + auto r = remote_engine::GetReplacements(header_template); + for (auto s : r) { + if (test_value.isMember(s)) { + replacements.insert({s, test_value[s].asString()}); + } + } + + auto result = + remote_engine::ReplaceHeaderPlaceholders(header_template, replacements); + + EXPECT_EQ(result[0], "Authorization: Bearer test"); + } } \ No newline at end of file