diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index d006f0f2d..2deb15e5e 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -5388,8 +5388,8 @@ "engine", "version", "inference_params", - "TransformReq", - "TransformResp", + "transform_req", + "transform_resp", "metadata" ], "properties": { @@ -5397,9 +5397,9 @@ "type": "string", "description": "The identifier of the model." }, - "api_key_template": { + "header_template": { "type": "string", - "description": "Template for the API key header." + "description": "Template for the header." }, "engine": { "type": "string", @@ -5432,7 +5432,7 @@ } } }, - "TransformReq": { + "transform_req": { "type": "object", "properties": { "get_models": { @@ -5454,7 +5454,7 @@ } } }, - "TransformResp": { + "transform_resp": { "type": "object", "properties": { "chat_completions": { @@ -6162,9 +6162,9 @@ "description": "Number of GPU layers.", "example": 33 }, - "api_key_template": { + "header_template": { "type": "string", - "description": "Template for the API key header." + "description": "Template for the header." }, "version": { "type": "string", diff --git a/engine/common/engine_servicei.h b/engine/common/engine_servicei.h index a4b0c8732..ceb9b2fec 100644 --- a/engine/common/engine_servicei.h +++ b/engine/common/engine_servicei.h @@ -25,12 +25,14 @@ struct EngineVariantResponse { std::string name; std::string version; std::string engine; + std::string type; Json::Value ToJson() const { Json::Value root; root["name"] = name; root["version"] = version; root["engine"] = engine; + root["type"] = type.empty() ? "local" : type; return root; } }; @@ -57,7 +59,7 @@ class EngineServiceI { virtual cpp::result GetEngineByNameAndVariant( const std::string& engine_name, - const std::optional variant = std::nullopt) = 0; - - virtual bool IsRemoteEngine(const std::string& engine_name) = 0; + const std::optional variant = std::nullopt) const = 0; + + virtual bool IsRemoteEngine(const std::string& engine_name) const = 0; }; diff --git a/engine/config/model_config.h b/engine/config/model_config.h index ea671354e..1d51cfb01 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -11,7 +11,6 @@ #include #include #include -#include "config/remote_template.h" #include "utils/format_utils.h" #include "utils/remote_models_utils.h" @@ -19,15 +18,15 @@ namespace config { struct RemoteModelConfig { std::string model; - std::string api_key_template; + std::string header_template; std::string engine; std::string version; - std::size_t created; + size_t created; std::string object = "model"; std::string owned_by = ""; Json::Value inference_params; - Json::Value TransformReq; - Json::Value TransformResp; + Json::Value transform_req; + Json::Value transform_resp; Json::Value metadata; void LoadFromJson(const Json::Value& json) { if (!json.isObject()) { @@ -36,8 +35,8 @@ struct RemoteModelConfig { // Load basic string fields model = json.get("model", model).asString(); - api_key_template = - json.get("api_key_template", api_key_template).asString(); + header_template = + json.get("header_template", header_template).asString(); engine = json.get("engine", engine).asString(); version = json.get("version", version).asString(); created = @@ -47,31 +46,8 @@ struct RemoteModelConfig { // Load JSON object fields directly inference_params = json.get("inference_params", inference_params); - TransformReq = json.get("TransformReq", TransformReq); - // Use default template if it is empty, currently we only support 2 remote engines - auto is_anthropic = [](const std::string& model) { - return model.find("claude") != std::string::npos; - }; - if (TransformReq["chat_completions"]["template"].isNull()) { - if (is_anthropic(model)) { - TransformReq["chat_completions"]["template"] = - kAnthropicTransformReqTemplate; - } else { - TransformReq["chat_completions"]["template"] = - kOpenAITransformReqTemplate; - } - } - TransformResp = json.get("TransformResp", TransformResp); - if (TransformResp["chat_completions"]["template"].isNull()) { - if (is_anthropic(model)) { - TransformResp["chat_completions"]["template"] = - kAnthropicTransformRespTemplate; - } else { - TransformResp["chat_completions"]["template"] = - kOpenAITransformRespTemplate; - } - } - + transform_req = json.get("transform_req", transform_req); + transform_resp = json.get("transform_resp", transform_resp); metadata = json.get("metadata", metadata); } @@ -80,7 +56,7 @@ struct RemoteModelConfig { // Add basic string fields json["model"] = model; - json["api_key_template"] = api_key_template; + json["header_template"] = header_template; json["engine"] = engine; json["version"] = version; json["created"] = static_cast(created); @@ -89,8 +65,8 @@ struct RemoteModelConfig { // Add JSON object fields directly json["inference_params"] = inference_params; - json["TransformReq"] = TransformReq; - json["TransformResp"] = TransformResp; + json["transform_req"] = transform_req; + json["transform_resp"] = transform_resp; json["metadata"] = metadata; return json; @@ -101,7 +77,7 @@ struct RemoteModelConfig { // Convert basic fields root["model"] = model; - root["api_key_template"] = api_key_template; + root["header_template"] = header_template; root["engine"] = engine; root["version"] = version; root["object"] = object; @@ -111,8 +87,8 @@ struct RemoteModelConfig { // Convert Json::Value to YAML::Node using utility function root["inference_params"] = remote_models_utils::jsonToYaml(inference_params); - root["TransformReq"] = remote_models_utils::jsonToYaml(TransformReq); - root["TransformResp"] = remote_models_utils::jsonToYaml(TransformResp); + root["transform_req"] = remote_models_utils::jsonToYaml(transform_req); + root["transform_resp"] = remote_models_utils::jsonToYaml(transform_resp); root["metadata"] = remote_models_utils::jsonToYaml(metadata); // Save to file @@ -134,7 +110,7 @@ struct RemoteModelConfig { // Load basic fields model = root["model"].as(""); - api_key_template = root["api_key_template"].as(""); + header_template = root["header_template"].as(""); engine = root["engine"].as(""); version = root["version"] ? root["version"].as() : ""; created = root["created"] ? root["created"].as() : 0; @@ -144,8 +120,8 @@ struct RemoteModelConfig { // Load complex fields using utility function inference_params = remote_models_utils::yamlToJson(root["inference_params"]); - TransformReq = remote_models_utils::yamlToJson(root["TransformReq"]); - TransformResp = remote_models_utils::yamlToJson(root["TransformResp"]); + transform_req = remote_models_utils::yamlToJson(root["transform_req"]); + transform_resp = remote_models_utils::yamlToJson(root["transform_resp"]); metadata = remote_models_utils::yamlToJson(root["metadata"]); } }; diff --git a/engine/config/remote_template.h b/engine/config/remote_template.h deleted file mode 100644 index 8a17aaa9a..000000000 --- a/engine/config/remote_template.h +++ /dev/null @@ -1,66 +0,0 @@ -#include - -namespace config { -const std::string kOpenAITransformReqTemplate = - R"({ {% set first = true %} {% for key, value in input_request %} {% if key == "messages" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} {% if not first %},{% endif %} "{{ key }}": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; -const std::string kOpenAITransformRespTemplate = - R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == "id" or key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "system_fingerprint" or key == "object" or key == "usage" -%} {%- if not first -%},{%- endif -%} "{{ key }}": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; -const std::string kAnthropicTransformReqTemplate = - R"({ - {% for key, value in input_request %} - {% if key == "messages" %} - {% if input_request.messages.0.role == "system" %} - "system": "{{ input_request.messages.0.content }}", - "messages": [ - {% for message in input_request.messages %} - {% if not loop.is_first %} - {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} - {% endif %} - {% endfor %} - ] - {% else %} - "messages": [ - {% for message in input_request.messages %} - {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} - {% endfor %} - ] - {% endif %} - {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} - "{{ key }}": {{ tojson(value) }} - {% endif %} - {% if not loop.is_last %},{% endif %} - {% endfor %} })"; -const std::string kAnthropicTransformRespTemplate = R"({ - "id": "{{ input_request.id }}", - "created": null, - "object": "chat.completion", - "model": "{{ input_request.model }}", - "choices": [ - { - "index": 0, - "message": { - "role": "{{ input_request.role }}", - "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", - "refusal": null - }, - "logprobs": null, - "finish_reason": "{{ input_request.stop_reason }}" - } - ], - "usage": { - "prompt_tokens": {{ input_request.usage.input_tokens }}, - "completion_tokens": {{ input_request.usage.output_tokens }}, - "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, - "prompt_tokens_details": { - "cached_tokens": 0 - }, - "completion_tokens_details": { - "reasoning_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0 - } - }, - "system_fingerprint": "fp_6b68a8204b" - })"; - -} // namespace config \ No newline at end of file diff --git a/engine/controllers/engines.cc b/engine/controllers/engines.cc index 24e61ba4f..8cf98785e 100644 --- a/engine/controllers/engines.cc +++ b/engine/controllers/engines.cc @@ -3,7 +3,9 @@ #include "utils/archive_utils.h" #include "utils/cortex_utils.h" #include "utils/engine_constants.h" +#include "utils/http_util.h" #include "utils/logging_utils.h" +#include "utils/scope_exit.h" #include "utils/string_utils.h" namespace { @@ -185,21 +187,58 @@ void Engines::InstallEngine( norm_version = version; } - if ((req->getJsonObject()) && - (*(req->getJsonObject())).get("type", "").asString() == "remote") { - auto type = (*(req->getJsonObject())).get("type", "").asString(); - auto api_key = (*(req->getJsonObject())).get("api_key", "").asString(); - auto url = (*(req->getJsonObject())).get("url", "").asString(); + auto result = + engine_service_->InstallEngineAsync(engine, norm_version, norm_variant); + if (result.has_error()) { + Json::Value res; + res["message"] = result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + CTL_INF("Error: " << result.error()); + callback(resp); + } else { + Json::Value res; + res["message"] = "Engine starts installing!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k200OK); + CTL_INF("Engine starts installing!"); + callback(resp); + } +} + +void Engines::InstallRemoteEngine( + const HttpRequestPtr& req, + std::function&& callback) { + if (!http_util::HasFieldInReq(req, callback, "engine")) { + return; + } + std::optional norm_variant = std::nullopt; + std::string norm_version{"latest"}; + + if (req->getJsonObject() != nullptr) { + auto variant = (*(req->getJsonObject())).get("variant", "").asString(); + auto version = + (*(req->getJsonObject())).get("version", "latest").asString(); + + if (!variant.empty()) { + norm_variant = variant; + } + norm_version = version; + } + + std::string engine; + if (auto o = req->getJsonObject(); o) { + engine = (*o).get("engine", "").asString(); + auto type = (*o).get("type", "").asString(); + auto api_key = (*o).get("api_key", "").asString(); + auto url = (*o).get("url", "").asString(); auto variant = norm_variant.value_or("all-platforms"); - auto status = (*(req->getJsonObject())).get("status", "Default").asString(); + auto status = (*o).get("status", "Default").asString(); std::string metadata; - if ((*(req->getJsonObject())).isMember("metadata") && - (*(req->getJsonObject()))["metadata"].isObject()) { - metadata = (*(req->getJsonObject())) - .get("metadata", Json::Value(Json::objectValue)) - .toStyledString(); - } else if ((*(req->getJsonObject())).isMember("metadata") && - !(*(req->getJsonObject()))["metadata"].isObject()) { + if ((*o).isMember("metadata") && (*o)["metadata"].isObject()) { + metadata = + (*o).get("metadata", Json::Value(Json::objectValue)).toStyledString(); + } else if ((*o).isMember("metadata") && !(*o)["metadata"].isObject()) { Json::Value res; res["message"] = "metadata must be object"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); @@ -208,8 +247,7 @@ void Engines::InstallEngine( return; } - auto get_models_url = (*(req->getJsonObject())) - .get("metadata", Json::Value(Json::objectValue)) + auto get_models_url = (*o).get("metadata", Json::Value(Json::objectValue)) .get("get_models_url", "") .asString(); @@ -262,25 +300,6 @@ void Engines::InstallEngine( resp->setStatusCode(k200OK); callback(resp); } - return; - } - - auto result = - engine_service_->InstallEngineAsync(engine, norm_version, norm_variant); - if (result.has_error()) { - Json::Value res; - res["message"] = result.error(); - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k400BadRequest); - CTL_INF("Error: " << result.error()); - callback(resp); - } else { - Json::Value res; - res["message"] = "Engine starts installing!"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); - resp->setStatusCode(k200OK); - CTL_INF("Engine starts installing!"); - callback(resp); } } @@ -288,6 +307,24 @@ void Engines::GetInstalledEngineVariants( const HttpRequestPtr& req, std::function&& callback, const std::string& engine) const { + + if (engine_service_->IsRemoteEngine(engine)) { + auto remote_engines = engine_service_->GetEngines(); + Json::Value releases(Json::arrayValue); + if (remote_engines.has_value()) { + for (auto e : remote_engines.value()) { + if (e.type == kRemote && e.engine_name == engine) { + releases.append(e.ToJson()); + break; + } + } + } + auto resp = cortex_utils::CreateCortexHttpJsonResponse(releases); + resp->setStatusCode(k200OK); + callback(resp); + return; + } + auto result = engine_service_->GetInstalledEngineVariants(engine); if (result.has_error()) { Json::Value res; @@ -310,6 +347,65 @@ void Engines::UpdateEngine( const HttpRequestPtr& req, std::function&& callback, const std::string& engine) { + + if (engine_service_->IsRemoteEngine(engine)) { + auto exist_engine = engine_service_->GetEngineByNameAndVariant(engine); + // only allow 1 variant 1 version of a remote engine name + if (!exist_engine) { + Json::Value res; + res["message"] = "Remote engine '" + engine + "' is not installed"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + if (auto o = req->getJsonObject(); o) { + auto type = (*o).get("type", (*exist_engine).type).asString(); + auto api_key = (*o).get("api_key", (*exist_engine).api_key).asString(); + auto url = (*o).get("url", (*exist_engine).url).asString(); + auto status = (*o).get("status", (*exist_engine).status).asString(); + auto version = (*o).get("version", "latest").asString(); + std::string metadata; + if ((*o).isMember("metadata") && (*o)["metadata"].isObject()) { + metadata = (*o).get("metadata", Json::Value(Json::objectValue)) + .toStyledString(); + } else if ((*o).isMember("metadata") && !(*o)["metadata"].isObject()) { + Json::Value res; + res["message"] = "metadata must be object"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } else { + metadata = (*exist_engine).metadata; + } + + auto upd_res = + engine_service_->UpsertEngine(engine, type, api_key, url, version, + "all-platforms", status, metadata); + if (upd_res.has_error()) { + Json::Value res; + res["message"] = upd_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + Json::Value res; + res["message"] = "Remote Engine update successfully!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k200OK); + callback(resp); + } + } else { + Json::Value res; + res["message"] = "Request body is empty!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + } + } + return; + } + auto result = engine_service_->UpdateEngine(engine); if (result.has_error()) { Json::Value res; diff --git a/engine/controllers/engines.h b/engine/controllers/engines.h index 3ad9708e3..78df6ccfb 100644 --- a/engine/controllers/engines.h +++ b/engine/controllers/engines.h @@ -16,6 +16,8 @@ class Engines : public drogon::HttpController { METHOD_ADD(Engines::InstallEngine, "/{1}/install", Options, Post); ADD_METHOD_TO(Engines::InstallEngine, "/v1/engines/{1}/install", Options, Post); + METHOD_ADD(Engines::InstallRemoteEngine, "/engines", Options, Post); + ADD_METHOD_TO(Engines::InstallRemoteEngine, "/v1/engines", Options, Post); // uninstall engine METHOD_ADD(Engines::UninstallEngine, "/{1}/install", Options, Delete); @@ -68,6 +70,10 @@ class Engines : public drogon::HttpController { std::function&& callback, const std::string& engine); + void InstallRemoteEngine( + const HttpRequestPtr& req, + std::function&& callback); + void UninstallEngine(const HttpRequestPtr& req, std::function&& callback, const std::string& engine); diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 34c6504ac..5bf02aa46 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -701,11 +701,10 @@ void Models::AddRemoteModel( // Use relative path for model_yaml_path. In case of import, we use absolute path for model auto yaml_rel_path = fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); - // TODO: remove hardcode "openai" when engine is finish cortex::db::ModelEntry model_entry{ model_handle, "", "", yaml_rel_path.string(), model_handle, "remote", "imported", cortex::db::ModelStatus::Remote, - "openai"}; + engine_name}; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); if (db_service_->AddModelEntry(model_entry).value()) { diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 83eaddb4e..a8cff2166 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -44,6 +44,11 @@ void server::ChatCompletion( } }(); + if (auto efm = inference_svc_->GetEngineByModelId(model_id); !efm.empty()) { + engine_type = efm; + (*json_body)["engine"] = efm; + } + LOG_DEBUG << "request body: " << json_body->toStyledString(); auto q = std::make_shared(); auto ir = inference_svc_->HandleChatCompletion(q, json_body); @@ -203,7 +208,6 @@ void server::RouteRequest( ProcessNonStreamRes(callback, *q); LOG_TRACE << "Done route request"; } - } void server::LoadModel(const HttpRequestPtr& req, @@ -223,7 +227,7 @@ void server::ProcessStreamRes(std::function cb, auto err_or_done = std::make_shared(false); auto chunked_content_provider = [this, q, err_or_done, engine_type, model_id]( char* buf, - std::size_t buf_size) -> std::size_t { + std::size_t buf_size) -> std::size_t { if (buf == nullptr) { LOG_TRACE << "Buf is null"; if (!(*err_or_done)) { @@ -243,7 +247,12 @@ void server::ProcessStreamRes(std::function cb, *err_or_done = true; } - auto str = res["data"].asString(); + std::string str; + if (status["status_code"].asInt() != k200OK) { + str = json_helper::DumpJsonString(res); + } else { + str = res["data"].asString(); + } LOG_DEBUG << "data: " << str; std::size_t n = std::min(str.size(), buf_size); memcpy(buf, str.data(), n); diff --git a/engine/cortex-common/remote_enginei.h b/engine/cortex-common/remote_enginei.h index 81ffbf5cd..835f526a0 100644 --- a/engine/cortex-common/remote_enginei.h +++ b/engine/cortex-common/remote_enginei.h @@ -33,5 +33,7 @@ class RemoteEngineI { std::function&& callback) = 0; // Get available remote models - virtual Json::Value GetRemoteModels() = 0; + virtual Json::Value GetRemoteModels(const std::string& url, + const std::string& api_key, + const std::string& header_template) = 0; }; diff --git a/engine/database/engines.h b/engine/database/engines.h index 7429d0fa2..1312a9c67 100644 --- a/engine/database/engines.h +++ b/engine/database/engines.h @@ -27,7 +27,7 @@ struct EngineEntry { // Convert basic fields root["id"] = id; - root["engine_name"] = engine_name; + root["engine"] = engine_name; root["type"] = type; root["api_key"] = api_key; root["url"] = url; diff --git a/engine/extensions/remote-engine/helper.h b/engine/extensions/remote-engine/helper.h new file mode 100644 index 000000000..5a99e5f33 --- /dev/null +++ b/engine/extensions/remote-engine/helper.h @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include +#include + +namespace remote_engine { +std::vector GetReplacements(const std::string& header_template) { + std::vector replacements; + std::regex placeholder_regex(R"(\{\{(.*?)\}\})"); + std::smatch match; + + std::string template_copy = header_template; + while (std::regex_search(template_copy, match, placeholder_regex)) { + std::string key = match[1].str(); + replacements.push_back(key); + template_copy = match.suffix().str(); + } + + return replacements; +} + +std::vector ReplaceHeaderPlaceholders( + const std::string& header_template, + const std::unordered_map& replacements) { + std::vector result; + size_t start = 0; + size_t end = header_template.find("}}"); + + while (end != std::string::npos) { + // Extract the part + std::string part = header_template.substr(start, end - start + 2); + + // Replace variables in this part + for (const auto& var : replacements) { + std::string placeholder = "{{" + var.first + "}}"; + size_t pos = part.find(placeholder); + if (pos != std::string::npos) { + part.replace(pos, placeholder.length(), var.second); + } + } + + // Trim whitespace + part.erase(0, part.find_first_not_of(" \t\n\r\f\v")); + part.erase(part.find_last_not_of(" \t\n\r\f\v") + 1); + + // Add to result if not empty + if (!part.empty()) { + result.push_back(part); + } + + // Move to next part + start = end + 2; + end = header_template.find("}}", start); + } + + // Process any remaining part + if (start < header_template.length()) { + std::string part = header_template.substr(start); + + // Replace variables in this part + for (const auto& var : replacements) { + std::string placeholder = "{{" + var.first + "}}"; + size_t pos = part.find(placeholder); + if (pos != std::string::npos) { + part.replace(pos, placeholder.length(), var.second); + } + } + + // Trim whitespace + part.erase(0, part.find_first_not_of(" \t\n\r\f\v")); + part.erase(part.find_last_not_of(" \t\n\r\f\v") + 1); + + if (!part.empty()) { + result.push_back(part); + } + } + return result; +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 6361077dd..0d7ecbef1 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -1,8 +1,10 @@ #include "remote_engine.h" #include #include +#include #include #include +#include "helper.h" #include "utils/json_helper.h" #include "utils/logging_utils.h" namespace remote_engine { @@ -12,13 +14,6 @@ constexpr const int k400BadRequest = 400; constexpr const int k409Conflict = 409; constexpr const int k500InternalServerError = 500; constexpr const int kFileLoggerOption = 0; -bool is_anthropic(const std::string& model) { - return model.find("claude") != std::string::npos; -} - -bool is_openai(const std::string& model) { - return model.find("gpt") != std::string::npos; -} constexpr const std::array kAnthropicModels = { "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", @@ -31,9 +26,20 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { auto* context = static_cast(userdata); std::string chunk(ptr, size * nmemb); + CTL_DBG(chunk); + auto check_error = json_helper::ParseJsonString(chunk); + if (check_error.isMember("error")) { + CTL_WRN(chunk); + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = true; + status["status_code"] = k400BadRequest; + (*context->callback)(std::move(status), std::move(check_error)); + return size * nmemb; + } context->buffer += chunk; - // Process complete lines size_t pos; while ((pos = context->buffer.find('\n')) != std::string::npos) { @@ -59,24 +65,23 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, // Parse the JSON Json::Value chunk_json; - if (!is_openai(context->model)) { - std::string s = line.substr(6); - try { - auto root = json_helper::ParseJsonString(s); - root["model"] = context->model; - root["id"] = context->id; - root["stream"] = true; - auto result = context->renderer.Render(context->stream_template, root); - CTL_DBG(result); - chunk_json["data"] = "data: " + result + "\n\n"; - } catch (const std::exception& e) { - CTL_WRN("JSON parse error: " << e.what()); + std::string s = line; + if (line.size() > 6) + s = line.substr(6); + try { + auto root = json_helper::ParseJsonString(s); + if (root.getMemberNames().empty()) continue; - } - } else { - chunk_json["data"] = line + "\n\n"; + root["model"] = context->model; + root["id"] = context->id; + root["stream"] = true; + auto result = context->renderer.Render(context->stream_template, root); + CTL_DBG(result); + chunk_json["data"] = "data: " + result + "\n\n"; + } catch (const std::exception& e) { + CTL_WRN("JSON parse error: " << e.what()); + continue; } - Json::Reader reader; Json::Value status; status["is_done"] = false; @@ -102,17 +107,17 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( return response; } - std::string full_url = - config.transform_req["chat_completions"]["url"].as(); + std::string full_url = chat_url_; - struct curl_slist* headers = nullptr; - if (!config.api_key.empty()) { - headers = curl_slist_append(headers, api_key_template_.c_str()); + if (config.transform_req["chat_completions"]["url"]) { + full_url = + config.transform_req["chat_completions"]["url"].as(); } + CTL_DBG("full_url: " << full_url); - if (is_anthropic(config.model)) { - std::string v = "anthropic-version: " + config.version; - headers = curl_slist_append(headers, v.c_str()); + struct curl_slist* headers = nullptr; + for (auto const& h : header_) { + headers = curl_slist_append(headers, h.c_str()); } headers = curl_slist_append(headers, "Content-Type: application/json"); @@ -121,6 +126,12 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( headers = curl_slist_append(headers, "Connection: keep-alive"); std::string stream_template = chat_res_template_; + if (config.transform_resp["chat_completions"] && + config.transform_resp["chat_completions"]["template"]) { + // Model level overrides engine level + stream_template = + config.transform_resp["chat_completions"]["template"].as(); + } StreamContext context{ std::make_shared>( @@ -174,6 +185,21 @@ std::string ReplaceApiKeyPlaceholder(const std::string& templateStr, return result; } +std::vector ReplaceHeaderPlaceholders( + const std::string& template_str, Json::Value json_body) { + CTL_DBG(template_str); + auto keys = GetReplacements(template_str); + if (keys.empty()) + return std::vector{}; + std::unordered_map replacements; + for (auto const& k : keys) { + if (json_body.isMember(k)) { + replacements.insert({k, json_body[k].asString()}); + } + } + return ReplaceHeaderPlaceholders(template_str, replacements); +} + static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, std::string* data) { data->append(ptr, size * nmemb); @@ -181,7 +207,7 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, } RemoteEngine::RemoteEngine(const std::string& engine_name) - : engine_name_(engine_name) { + : engine_name_(engine_name), q_(1 /*n_parallel*/, engine_name) { curl_global_init(CURL_GLOBAL_ALL); } @@ -199,7 +225,9 @@ RemoteEngine::ModelConfig* RemoteEngine::GetModelConfig( return nullptr; } -CurlResponse RemoteEngine::MakeGetModelsRequest() { +CurlResponse RemoteEngine::MakeGetModelsRequest( + const std::string& url, const std::string& api_key, + const std::string& header_template) { CURL* curl = curl_easy_init(); CurlResponse response; @@ -209,13 +237,14 @@ CurlResponse RemoteEngine::MakeGetModelsRequest() { return response; } - std::string full_url = metadata_["get_models_url"].asString(); + std::string api_key_header = + ReplaceApiKeyPlaceholder(header_template, api_key); struct curl_slist* headers = nullptr; - headers = curl_slist_append(headers, api_key_template_.c_str()); + headers = curl_slist_append(headers, api_key_header.c_str()); headers = curl_slist_append(headers, "Content-Type: application/json"); - curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); std::string response_string; @@ -246,18 +275,19 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( response.error_message = "Failed to initialize CURL"; return response; } - std::string full_url = - config.transform_req["chat_completions"]["url"].as(); + std::string full_url = chat_url_; - struct curl_slist* headers = nullptr; - if (!config.api_key.empty()) { - headers = curl_slist_append(headers, api_key_template_.c_str()); + if (config.transform_req["chat_completions"]["url"]) { + full_url = + config.transform_req["chat_completions"]["url"].as(); } + CTL_DBG("full_url: " << full_url); - if (is_anthropic(config.model)) { - std::string v = "anthropic-version: " + config.version; - headers = curl_slist_append(headers, v.c_str()); + struct curl_slist* headers = nullptr; + for (auto const& h : header_) { + headers = curl_slist_append(headers, h.c_str()); } + headers = curl_slist_append(headers, "Content-Type: application/json"); curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); @@ -286,42 +316,30 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( bool RemoteEngine::LoadModelConfig(const std::string& model, const std::string& yaml_path, - const std::string& api_key) { + const Json::Value& body) { try { YAML::Node config = YAML::LoadFile(yaml_path); ModelConfig model_config; model_config.model = model; - if (is_anthropic(model)) { - if (!config["version"]) { - CTL_ERR("Missing version for model: " << model); - return false; - } - model_config.version = config["version"].as(); - } - - // Required fields - if (!config["api_key_template"]) { - LOG_ERROR << "Missing required fields in config for model " << model; - return false; - } - - model_config.api_key = api_key; + model_config.api_key = body["api_key"].asString(); // model_config.url = ; // Optional fields - if (config["api_key_template"]) { - api_key_template_ = ReplaceApiKeyPlaceholder( - config["api_key_template"].as(), api_key); + if (auto s = config["header_template"]; s && !s.as().empty()) { + header_ = ReplaceHeaderPlaceholders(s.as(), body); + for (auto const& h : header_) { + CTL_DBG("header: " << h); + } } - if (config["TransformReq"]) { - model_config.transform_req = config["TransformReq"]; + if (config["transform_req"]) { + model_config.transform_req = config["transform_req"]; } else { - LOG_WARN << "Missing TransformReq in config for model " << model; + LOG_WARN << "Missing transform_req in config for model " << model; } - if (config["TransformResp"]) { - model_config.transform_resp = config["TransformResp"]; + if (config["transform_resp"]) { + model_config.transform_resp = config["transform_resp"]; } else { - LOG_WARN << "Missing TransformResp in config for model " << model; + LOG_WARN << "Missing transform_resp in config for model " << model; } model_config.is_loaded = true; @@ -393,34 +411,54 @@ void RemoteEngine::LoadModel( const std::string& model_path = (*json_body)["model_path"].asString(); const std::string& api_key = (*json_body)["api_key"].asString(); - if (!LoadModelConfig(model, model_path, api_key)) { - Json::Value error; - error["error"] = "Failed to load model configuration"; - Json::Value status; - status["is_done"] = true; - status["has_error"] = true; - status["is_stream"] = false; - status["status_code"] = k500InternalServerError; - callback(std::move(status), std::move(error)); - return; - } if (json_body->isMember("metadata")) { metadata_ = (*json_body)["metadata"]; - if (!metadata_["TransformReq"].isNull() && - !metadata_["TransformReq"]["chat_completions"].isNull() && - !metadata_["TransformReq"]["chat_completions"]["template"].isNull()) { + if (!metadata_["transform_req"].isNull() && + !metadata_["transform_req"]["chat_completions"].isNull() && + !metadata_["transform_req"]["chat_completions"]["template"].isNull()) { chat_req_template_ = - metadata_["TransformReq"]["chat_completions"]["template"].asString(); + metadata_["transform_req"]["chat_completions"]["template"].asString(); CTL_INF(chat_req_template_); } - if (!metadata_["TransformResp"].isNull() && - !metadata_["TransformResp"]["chat_completions"].isNull() && - !metadata_["TransformResp"]["chat_completions"]["template"].isNull()) { + if (!metadata_["transform_resp"].isNull() && + !metadata_["transform_resp"]["chat_completions"].isNull() && + !metadata_["transform_resp"]["chat_completions"]["template"].isNull()) { chat_res_template_ = - metadata_["TransformResp"]["chat_completions"]["template"].asString(); + metadata_["transform_resp"]["chat_completions"]["template"] + .asString(); CTL_INF(chat_res_template_); } + + if (!metadata_["transform_req"].isNull() && + !metadata_["transform_req"]["chat_completions"].isNull() && + !metadata_["transform_req"]["chat_completions"]["url"].isNull()) { + chat_url_ = + metadata_["transform_req"]["chat_completions"]["url"].asString(); + CTL_INF(chat_url_); + } + } + + if (json_body->isMember("metadata")) { + if (!metadata_["header_template"].isNull()) { + header_ = ReplaceHeaderPlaceholders( + metadata_["header_template"].asString(), *json_body); + for (auto const& h : header_) { + CTL_DBG("header: " << h); + } + } + } + + if (!LoadModelConfig(model, model_path, *json_body)) { + Json::Value error; + error["error"] = "Failed to load model configuration"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + callback(std::move(status), std::move(error)); + return; } Json::Value response; @@ -501,15 +539,6 @@ void RemoteEngine::HandleChatCompletion( // Transform request std::string result; try { - // Check if required YAML nodes exist - if (!model_config->transform_req["chat_completions"]) { - throw std::runtime_error( - "Missing 'chat_completions' node in transform_req"); - } - if (!model_config->transform_req["chat_completions"]["template"]) { - throw std::runtime_error("Missing 'template' node in chat_completions"); - } - // Validate JSON body if (!json_body || json_body->isNull()) { throw std::runtime_error("Invalid or null JSON body"); @@ -517,12 +546,16 @@ void RemoteEngine::HandleChatCompletion( // Get template string with error check std::string template_str; - try { + if (!chat_req_template_.empty()) { + CTL_DBG("Use engine transform request template: " << chat_req_template_); + template_str = chat_req_template_; + } + if (model_config->transform_req["chat_completions"] && + model_config->transform_req["chat_completions"]["template"]) { + // Model level overrides engine level template_str = model_config->transform_req["chat_completions"]["template"] .as(); - } catch (const YAML::BadConversion& e) { - throw std::runtime_error("Failed to convert template node to string: " + - std::string(e.what())); + CTL_DBG("Use model transform request template: " << template_str); } // Render with error handling @@ -540,7 +573,9 @@ void RemoteEngine::HandleChatCompletion( } if (is_stream) { - MakeStreamingChatCompletionRequest(*model_config, result, callback); + q_.runTaskInQueue([this, model_config, result, cb = std::move(callback)] { + MakeStreamingChatCompletionRequest(*model_config, result, cb); + }); } else { auto response = MakeChatCompletionRequest(*model_config, result); @@ -579,33 +614,14 @@ void RemoteEngine::HandleChatCompletion( CTL_DBG( "Use engine transform response template: " << chat_res_template_); template_str = chat_res_template_; - } else { - // Check if required YAML nodes exist - if (!model_config->transform_resp["chat_completions"]) { - throw std::runtime_error( - "Missing 'chat_completions' node in transform_resp"); - } - if (!model_config->transform_resp["chat_completions"]["template"]) { - throw std::runtime_error( - "Missing 'template' node in chat_completions"); - } - - // Validate JSON body - if (!response_json || response_json.isNull()) { - throw std::runtime_error("Invalid or null JSON body"); - } - - // Get template string with error check - - try { - template_str = - model_config->transform_resp["chat_completions"]["template"] - .as(); - } catch (const YAML::BadConversion& e) { - throw std::runtime_error( - "Failed to convert template node to string: " + - std::string(e.what())); - } + } + if (model_config->transform_resp["chat_completions"] && + model_config->transform_resp["chat_completions"]["template"]) { + // Model level overrides engine level + template_str = + model_config->transform_resp["chat_completions"]["template"] + .as(); + CTL_DBG("Use model transform request template: " << template_str); } try { @@ -686,9 +702,10 @@ void RemoteEngine::HandleEmbedding( callback(Json::Value(), Json::Value()); } -Json::Value RemoteEngine::GetRemoteModels() { - if (metadata_["get_models_url"].isNull() || - metadata_["get_models_url"].asString().empty()) { +Json::Value RemoteEngine::GetRemoteModels(const std::string& url, + const std::string& api_key, + const std::string& header_template) { + if (url.empty()) { if (engine_name_ == kAnthropicEngine) { Json::Value json_resp; Json::Value model_array(Json::arrayValue); @@ -709,20 +726,19 @@ Json::Value RemoteEngine::GetRemoteModels() { return Json::Value(); } } else { - auto response = MakeGetModelsRequest(); + auto response = MakeGetModelsRequest(url, api_key, header_template); if (response.error) { Json::Value error; error["error"] = response.error_message; + CTL_WRN(response.error_message); return error; } - Json::Value response_json; - Json::Reader reader; - if (!reader.parse(response.body, response_json)) { - Json::Value error; - error["error"] = "Failed to parse response"; - return error; + CTL_DBG(response.body); + auto body_json = json_helper::ParseJsonString(response.body); + if (body_json.isMember("error")) { + return body_json["error"]; } - return response_json; + return body_json; } } diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 6f08b5403..bc6d534c5 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -9,6 +9,7 @@ #include #include "cortex-common/remote_enginei.h" #include "extensions/template_renderer.h" +#include "trantor/utils/ConcurrentTaskQueue.h" #include "utils/engine_constants.h" #include "utils/file_logger.h" // Helper for CURL response @@ -50,8 +51,10 @@ class RemoteEngine : public RemoteEngineI { Json::Value metadata_; std::string chat_req_template_; std::string chat_res_template_; - std::string api_key_template_; + std::vector header_; std::string engine_name_; + std::string chat_url_; + trantor::ConcurrentTaskQueue q_; // Helper functions CurlResponse MakeChatCompletionRequest(const ModelConfig& config, @@ -60,11 +63,13 @@ class RemoteEngine : public RemoteEngineI { CurlResponse MakeStreamingChatCompletionRequest( const ModelConfig& config, const std::string& body, const std::function& callback); - CurlResponse MakeGetModelsRequest(); + CurlResponse MakeGetModelsRequest(const std::string& url, + const std::string& api_key, + const std::string& header_template); // Internal model management bool LoadModelConfig(const std::string& model, const std::string& yaml_path, - const std::string& api_key); + const Json::Value& body); ModelConfig* GetModelConfig(const std::string& model); public: @@ -97,7 +102,9 @@ class RemoteEngine : public RemoteEngineI { std::shared_ptr json_body, std::function&& callback) override; - Json::Value GetRemoteModels() override; + Json::Value GetRemoteModels(const std::string& url, + const std::string& api_key, + const std::string& header_template) override; }; } // namespace remote_engine \ No newline at end of file diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index c6b107af3..a80c5fe60 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -6,10 +6,10 @@ #include #include "algorithm" +#include "config/model_config.h" #include "database/engines.h" - +#include "database/models.h" #include "extensions/python-engine/python_engine.h" - #include "extensions/remote-engine/remote_engine.h" #include "utils/archive_utils.h" @@ -736,9 +736,11 @@ cpp::result EngineService::LoadEngine( return cpp::fail("Remote engine '" + engine_name + "' is not installed"); } - engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name); - - CTL_INF("Loaded engine: " << engine_name); + if (!IsEngineLoaded(engine_name)) { + engines_[engine_name].engine = + new remote_engine::RemoteEngine(engine_name); + CTL_INF("Loaded engine: " << engine_name); + } return {}; } @@ -1074,7 +1076,8 @@ cpp::result EngineService::GetEngineById( cpp::result EngineService::GetEngineByNameAndVariant( - const std::string& engine_name, const std::optional variant) { + const std::string& engine_name, + const std::optional variant) const { assert(db_service_); auto get_res = db_service_->GetEngineByNameAndVariant(engine_name, variant); @@ -1123,17 +1126,29 @@ cpp::result EngineService::GetRemoteModels( return cpp::fail(r.error()); } + auto exist_engine = GetEngineByNameAndVariant(engine_name); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine_name + "' is not installed"); + } + if (!IsEngineLoaded(engine_name)) { - auto exist_engine = GetEngineByNameAndVariant(engine_name); - if (exist_engine.has_error()) { - return cpp::fail("Remote engine '" + engine_name + "' is not installed"); - } engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name); CTL_INF("Loaded engine: " << engine_name); } + auto remote_engine_json = exist_engine.value().ToJson(); auto& e = std::get(engines_[engine_name].engine); - auto res = e->GetRemoteModels(); + auto url = remote_engine_json["metadata"]["get_models_url"].asString(); + auto api_key = remote_engine_json["api_key"].asString(); + auto header_template = + remote_engine_json["metadata"]["header_template"].asString(); + if (url.empty()) + CTL_WRN("url is empty"); + if (api_key.empty()) + CTL_WRN("api_key is empty"); + if (header_template.empty()) + CTL_WRN("header_template is empty"); + auto res = e->GetRemoteModels(url, api_key, header_template); if (!res["error"].isNull()) { return cpp::fail(res["error"].asString()); } else { @@ -1141,7 +1156,7 @@ cpp::result EngineService::GetRemoteModels( } } -bool EngineService::IsRemoteEngine(const std::string& engine_name) { +bool EngineService::IsRemoteEngine(const std::string& engine_name) const { auto ne = Repo2Engine(engine_name); auto local_engines = file_manager_utils::GetCortexConfig().supportedEngines; for (auto const& le : local_engines) { diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index a460582c6..f98037bab 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -139,7 +139,7 @@ class EngineService : public EngineServiceI { cpp::result GetEngineByNameAndVariant( const std::string& engine_name, - const std::optional variant = std::nullopt) override; + const std::optional variant = std::nullopt) const override; cpp::result UpsertEngine( const std::string& engine_name, const std::string& type, @@ -155,7 +155,7 @@ class EngineService : public EngineServiceI { void RegisterEngineLibPath(); - bool IsRemoteEngine(const std::string& engine_name) override; + bool IsRemoteEngine(const std::string& engine_name) const override; private: bool IsEngineLoaded(const std::string& engine); diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 3668fb6fe..057b6f716 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -394,3 +394,8 @@ bool InferenceService::HasFieldInReq(std::shared_ptr json_body, } return true; } + +std::string InferenceService::GetEngineByModelId( + const std::string& model_id) const { + return model_service_.lock()->GetEngineByModelId(model_id); +} diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index f23be3f23..794110f99 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -69,6 +69,8 @@ class InferenceService { model_service_ = model_service; } + std::string GetEngineByModelId(const std::string& model_id) const; + private: std::shared_ptr engine_service_; std::weak_ptr model_service_; diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 74767a9b2..f79c20859 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -768,7 +768,8 @@ cpp::result ModelService::DeleteModel( // Remove yaml file std::filesystem::remove(yaml_fp); // Remove model files if they are not imported locally - if (model_entry.value().branch_name != "imported") { + if (model_entry.value().branch_name != "imported" && + !engine_svc_->IsRemoteEngine(mc.engine)) { if (mc.files.size() > 0) { if (mc.engine == kLlamaRepo || mc.engine == kLlamaEngine) { for (auto& file : mc.files) { @@ -890,7 +891,7 @@ cpp::result ModelService::StartModel( // Running remote model if (engine_svc_->IsRemoteEngine(mc.engine)) { - + engine_svc_->LoadEngine(mc.engine); config::RemoteModelConfig remote_mc; remote_mc.LoadFromYamlFile( fmu::ToAbsoluteCortexDataPath( @@ -906,6 +907,10 @@ cpp::result ModelService::StartModel( json_data = remote_mc.ToJson(); json_data["api_key"] = std::move(remote_engine_json["api_key"]); + if (auto v = remote_engine_json["version"].asString(); + !v.empty() && v != "latest") { + json_data["version"] = v; + } json_data["model_path"] = fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.value().path_to_model_yaml)) @@ -1374,5 +1379,27 @@ ModelService::GetModelMetadata(const std::string& model_id) const { std::shared_ptr ModelService::GetCachedModelMetadata( const std::string& model_id) const { + if (loaded_model_metadata_map_.find(model_id) == + loaded_model_metadata_map_.end()) + return nullptr; return loaded_model_metadata_map_.at(model_id); } + +std::string ModelService::GetEngineByModelId( + const std::string& model_id) const { + namespace fs = std::filesystem; + namespace fmu = file_manager_utils; + auto model_entry = db_service_->GetModelInfo(model_id); + if (model_entry.has_error()) { + CTL_WRN("Error: " + model_entry.error()); + return ""; + } + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + auto mc = yaml_handler.GetModelConfig(); + CTL_DBG(mc.engine); + return mc.engine; +} \ No newline at end of file diff --git a/engine/services/model_service.h b/engine/services/model_service.h index cc659fea5..a668b27ba 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -96,6 +96,8 @@ class ModelService { std::shared_ptr GetCachedModelMetadata( const std::string& model_id) const; + std::string GetEngineByModelId(const std::string& model_id) const; + private: /** * Handle downloading model which have following pattern: author/model_name diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc index 5f1b85044..3bb3cdca3 100644 --- a/engine/test/components/test_remote_engine.cc +++ b/engine/test/components/test_remote_engine.cc @@ -1,3 +1,7 @@ +#include +#include +#include +#include "extensions/remote-engine/helper.h" #include "extensions/template_renderer.h" #include "gtest/gtest.h" #include "utils/json_helper.h" @@ -25,19 +29,22 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) { {% endfor %} ] {% endif %} + {% if not loop.is_last %},{% endif %} {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} "{{ key }}": {{ tojson(value) }} + {% if not loop.is_last %},{% endif %} {% endif %} - {% if not loop.is_last %},{% endif %} {% endfor %} })"; { std::string message_with_system = R"({ + "engine" : "anthropic", + "max_tokens" : 1024, "messages": [ {"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."}, {"role": "user", "content": "Hello, world"} ], "model": "claude-3-5-sonnet-20241022", - "max_tokens": 1024, + "stream" : true })"; auto data = json_helper::ParseJsonString(message_with_system); @@ -78,4 +85,160 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) { EXPECT_EQ(data["messages"][0]["content"].asString(), res_json["messages"][0]["content"].asString()); } +} + +TEST_F(RemoteEngineTest, OpenAiResponse) { + std::string tpl = R"({ + {% set first = true %} + {% for key, value in input_request %} + {% if key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "system_fingerprint" or key == "stream" or key == "object" or key == "usage" %} + {% if not first %},{% endif %} + "{{ key }}": {{ tojson(value) }} + {% set first = false %} + {% endif %} + {% endfor %} + })"; + std::string message = R"( + { + "choices": [ + { + "delta": { + "content": " questions" + }, + "finish_reason": null, + "index": 0 + } + ], + "created": 1735372587, + "id": "", + "model": "o1-preview", + "object": "chat.completion.chunk", + "stream": true, + "system_fingerprint": "fp_1ddf0263de" + })"; + auto data = json_helper::ParseJsonString(message); + + extensions::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["created"].asInt(), res_json["created"].asInt()); + EXPECT_EQ(data["choices"][0]["delta"]["content"].asString(), + res_json["choices"][0]["delta"]["content"].asString()); +} + +TEST_F(RemoteEngineTest, AnthropicResponse) { + std::string tpl = R"( + {% if input_request.stream %} + {"object": "chat.completion.chunk", "model": "{{ input_request.model }}", "choices": [{"index": 0, "delta": { {% if input_request.type == "message_start" %} "role": "assistant", "content": null {% else if input_request.type == "ping" %} "role": "assistant", "content": null {% else if input_request.type == "content_block_delta" %} "role": "assistant", "content": "{{ input_request.delta.text }}" {% else if input_request.type == "content_block_stop" %} "role": "assistant", "content": null {% else if input_request.type == "content_block_stop" %} "role": "assistant", "content": null {% endif %} }, {% if input_request.type == "content_block_stop" %} "finish_reason": "stop" {% else %} "finish_reason": null {% endif %} }]} + {% else %} + {"id": "{{ input_request.id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [{ + "index": 0, + "message": { + "role": "{{ input_request.role }}", + "content": {% if not input_request.content %} null {% else if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% else %} null {% endif %}, + "refusal": null }, + "logprobs": null, + "finish_reason": "{{ input_request.stop_reason }}" } ], + "usage": { + "prompt_tokens": {{ input_request.usage.input_tokens }}, + "completion_tokens": {{ input_request.usage.output_tokens }}, + "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, + "prompt_tokens_details": { "cached_tokens": 0 }, + "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, + "system_fingerprint": "fp_6b68a8204b"} + {% endif %})"; + std::string message = R"( + { + "content": [], + "id": "msg_01SckpnDyChcmmawQsWHr8CH", + "model": "claude-3-opus-20240229", + "role": "assistant", + "stop_reason": "end_turn", + "stop_sequence": null, + "stream": false, + "type": "message", + "usage": { + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "input_tokens": 130, + "output_tokens": 3 + } + })"; + auto data = json_helper::ParseJsonString(message); + + extensions::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["created"].asInt(), res_json["created"].asInt()); + EXPECT_TRUE(res_json["choices"][0]["message"]["content"].isNull()); +} + +TEST_F(RemoteEngineTest, HeaderTemplate) { + { + std::string header_template = + R"(x-api-key: {{api_key}} anthropic-version: {{version}})"; + Json::Value test_value; + test_value["api_key"] = "test"; + test_value["version"] = "test_version"; + std::unordered_map replacements; + auto r = remote_engine::GetReplacements(header_template); + for (auto s : r) { + if (test_value.isMember(s)) { + replacements.insert({s, test_value[s].asString()}); + } + } + + auto result = + remote_engine::ReplaceHeaderPlaceholders(header_template, replacements); + + EXPECT_EQ(result[0], "x-api-key: test"); + EXPECT_EQ(result[1], "anthropic-version: test_version"); + } + + { + std::string header_template = + R"(x-api-key: {{api_key}} anthropic-version: test_version)"; + Json::Value test_value; + test_value["api_key"] = "test"; + test_value["version"] = "test_version"; + std::unordered_map replacements; + auto r = remote_engine::GetReplacements(header_template); + for (auto s : r) { + if (test_value.isMember(s)) { + replacements.insert({s, test_value[s].asString()}); + } + } + + auto result = + remote_engine::ReplaceHeaderPlaceholders(header_template, replacements); + + EXPECT_EQ(result[0], "x-api-key: test"); + EXPECT_EQ(result[1], "anthropic-version: test_version"); + } + + { + std::string header_template = R"(Authorization: Bearer {{api_key}}")"; + Json::Value test_value; + test_value["api_key"] = "test"; + std::unordered_map replacements; + auto r = remote_engine::GetReplacements(header_template); + for (auto s : r) { + if (test_value.isMember(s)) { + replacements.insert({s, test_value[s].asString()}); + } + } + + auto result = + remote_engine::ReplaceHeaderPlaceholders(header_template, replacements); + + EXPECT_EQ(result[0], "Authorization: Bearer test"); + } } \ No newline at end of file