Feat/stream request python engine #1829

Draft · wants to merge 5 commits into feat/python-engine
63 changes: 48 additions & 15 deletions engine/controllers/server.cc
@@ -129,6 +129,7 @@ void server::FineTuning(

void server::Inference(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
auto json_body = req->getJsonObject();
LOG_TRACE << "Start inference";
auto q = std::make_shared<services::SyncQueue>();
auto ir = inference_svc_->HandleInference(q, req->getJsonObject());
@@ -141,20 +142,36 @@ void server::Inference(const HttpRequestPtr& req,
callback(resp);
return;
}
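// "stream" may be set at the top level or nested inside "body",
// depending on how the request was composed.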
bool is_stream =
(*json_body).get("stream", false).asBool() ||
(*json_body).get("body", Json::Value()).get("stream", false).asBool();

LOG_TRACE << "Wait to inference";
auto [status, res] = q->wait_and_pop();
LOG_DEBUG << "response: " << res.toStyledString();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(
static_cast<drogon::HttpStatusCode>(status["status_code"].asInt()));
callback(resp);
LOG_TRACE << "Done inference";
if (is_stream) {
auto model_id = (*json_body).get("model", "invalid_model").asString();
auto engine_type = [this, &json_body]() -> std::string {
if (!inference_svc_->HasFieldInReq(json_body, "engine")) {
return kLlamaRepo;
} else {
return (*(json_body)).get("engine", kLlamaRepo).asString();
}
}();
ProcessStreamRes(callback, q, engine_type, model_id);
} else {
auto [status, res] = q->wait_and_pop();
LOG_DEBUG << "response: " << res.toStyledString();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(
static_cast<drogon::HttpStatusCode>(status["status_code"].asInt()));
callback(resp);
LOG_TRACE << "Done inference";
}
}

void server::RouteRequest(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {

auto json_body = req->getJsonObject();
LOG_TRACE << "Start route request";
auto q = std::make_shared<services::SyncQueue>();
auto ir = inference_svc_->HandleRouteRequest(q, req->getJsonObject());
@@ -167,14 +184,30 @@ void server::RouteRequest(
callback(resp);
return;
}
bool is_stream =
(*json_body).get("stream", false).asBool() ||
(*json_body).get("body", Json::Value()).get("stream", false).asBool();
LOG_TRACE << "Wait to route request";
auto [status, res] = q->wait_and_pop();
LOG_DEBUG << "response: " << res.toStyledString();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(
static_cast<drogon::HttpStatusCode>(status["status_code"].asInt()));
callback(resp);
LOG_TRACE << "Done route request";
if (is_stream) {
auto model_id = (*json_body).get("model", "invalid_model").asString();
auto engine_type = [this, &json_body]() -> std::string {
if (!inference_svc_->HasFieldInReq(json_body, "engine")) {
return kLlamaRepo;
} else {
return (*(json_body)).get("engine", kLlamaRepo).asString();
}
}();
ProcessStreamRes(callback, q, engine_type, model_id);
} else {
auto [status, res] = q->wait_and_pop();
LOG_DEBUG << "response: " << res.toStyledString();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(
static_cast<drogon::HttpStatusCode>(status["status_code"].asInt()));
callback(resp);
LOG_TRACE << "Done route request";
}
}

void server::LoadModel(const HttpRequestPtr& req,
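Note: `ProcessStreamRes` itself is not part of this diff, but the pattern both handlers now rely on is a shared `services::SyncQueue` that the engine side pushes chunks into and the HTTP side drains until a chunk's status reports `is_done`. A self-contained sketch of that producer/consumer handoff, using illustrative stand-in types rather than the project's own:

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <utility>

// Illustrative stand-ins for the project's SyncQueue payloads.
struct Chunk {
  bool is_done = false;  // mirrors status["is_done"]
  std::string data;      // mirrors the streamed payload
};

class SyncQueue {
 public:
  void push(Chunk c) {
    {
      std::lock_guard<std::mutex> lk(m_);
      q_.push(std::move(c));
    }
    cv_.notify_one();
  }
  Chunk wait_and_pop() {
    std::unique_lock<std::mutex> lk(m_);
    cv_.wait(lk, [this] { return !q_.empty(); });
    Chunk c = std::move(q_.front());
    q_.pop();
    return c;
  }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  std::queue<Chunk> q_;
};

int main() {
  SyncQueue q;
  // Producer: the engine-side callback pushes chunks as they arrive.
  std::thread engine([&q] {
    for (auto piece : {"Hello", ", ", "world"})
      q.push({false, piece});
    q.push({true, ""});  // terminal chunk, like status["is_done"] = true
  });
  // Consumer: what a ProcessStreamRes-style loop boils down to.
  while (true) {
    Chunk c = q.wait_and_pop();
    if (c.is_done) break;
    std::cout << c.data << std::flush;  // forward to the HTTP response
  }
  std::cout << '\n';
  engine.join();
}
```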
71 changes: 68 additions & 3 deletions engine/extensions/python-engine/python_engine.cc
@@ -16,7 +16,7 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb,
return size * nmemb;
}

PythonEngine::PythonEngine() {
PythonEngine::PythonEngine() : q_(4 /*n_parallel*/, "python_engine") {
curl_global_init(CURL_GLOBAL_ALL);
}

@@ -509,6 +509,62 @@ void PythonEngine::HandleChatCompletion(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {}

CurlResponse PythonEngine::MakeStreamPostRequest(
const std::string& model, const std::string& path, const std::string& body,
const std::function<void(Json::Value&&, Json::Value&&)>& callback) {
auto config = models_[model];
CURL* curl = curl_easy_init();
CurlResponse response;

if (!curl) {
response.error = true;
response.error_message = "Failed to initialize CURL";
return response;
}

std::string full_url = "http://localhost:" + config.port + path;

struct curl_slist* headers = nullptr;
headers = curl_slist_append(headers, "Content-Type: application/json");
headers = curl_slist_append(headers, "Accept: text/event-stream");
headers = curl_slist_append(headers, "Cache-Control: no-cache");
headers = curl_slist_append(headers, "Connection: keep-alive");

StreamContext context{
std::make_shared<std::function<void(Json::Value&&, Json::Value&&)>>(
callback),
""};

curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str());
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, StreamWriteCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &context);
curl_easy_setopt(curl, CURLOPT_TRANSFER_ENCODING, 1L);

CURLcode res = curl_easy_perform(curl);

if (res != CURLE_OK) {
response.error = true;
response.error_message = curl_easy_strerror(res);

Json::Value status;
status["is_done"] = true;
status["has_error"] = true;
status["is_stream"] = true;
status["status_code"] = 500;

Json::Value error;
error["error"] = response.error_message;
callback(std::move(status), std::move(error));
}

curl_slist_free_all(headers);
curl_easy_cleanup(curl);
return response;
}

void PythonEngine::HandleInference(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
@@ -544,7 +600,7 @@ void PythonEngine::HandleInference(

// Render with error handling
try {
transformed_request = renderer_.Render(transform_request, *json_body);
transformed_request = renderer_.Render(transform_request, body);
} catch (const std::exception& e) {
throw std::runtime_error("Template rendering error: " +
std::string(e.what()));
@@ -563,7 +619,16 @@

CurlResponse response;
if (method == "post") {
response = MakePostRequest(model, path, transformed_request);
if (body.isMember("stream") && body["stream"].asBool()) {
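// curl_easy_perform blocks for the entire stream, so run the request on
// the worker queue rather than the caller's thread.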
q_.runTaskInQueue(
[this, model, path, transformed_request, cb = std::move(callback)] {
MakeStreamPostRequest(model, path, transformed_request, cb);
});

return;
} else {
response = MakePostRequest(model, path, transformed_request);
}
} else if (method == "get") {
response = MakeGetRequest(model, path);
} else if (method == "delete") {
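The heart of `MakeStreamPostRequest` is libcurl's write-callback streaming: `curl_easy_perform` blocks for the whole transfer while the callback fires once per received chunk, which is exactly why `HandleInference` moves the call onto `q_` instead of the caller's thread. A stripped-down, self-contained version of the pattern (URL, port, and payload are placeholders, not the project's values):

```cpp
#include <curl/curl.h>
#include <iostream>
#include <string>

// Fires once per chunk while curl_easy_perform() is still blocking.
static size_t OnChunk(char* ptr, size_t size, size_t nmemb, void* userdata) {
  auto* received = static_cast<std::string*>(userdata);
  received->append(ptr, size * nmemb);
  std::cout << "chunk of " << size * nmemb << " bytes\n";
  return size * nmemb;  // returning less aborts the transfer
}

int main() {
  curl_global_init(CURL_GLOBAL_ALL);
  CURL* curl = curl_easy_init();
  if (!curl) return 1;

  // Placeholder endpoint and body; the real code targets the spawned
  // Python server on http://localhost:<port>.
  const std::string body = R"({"stream": true})";
  struct curl_slist* headers = nullptr;
  headers = curl_slist_append(headers, "Content-Type: application/json");
  headers = curl_slist_append(headers, "Accept: text/event-stream");

  std::string received;
  curl_easy_setopt(curl, CURLOPT_URL, "http://localhost:8000/v1/stream");
  curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
  curl_easy_setopt(curl, CURLOPT_POST, 1L);
  curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());
  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, OnChunk);
  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &received);

  CURLcode res = curl_easy_perform(curl);  // blocks until the stream ends
  if (res != CURLE_OK)
    std::cerr << "curl: " << curl_easy_strerror(res) << '\n';

  curl_slist_free_all(headers);
  curl_easy_cleanup(curl);
  curl_global_cleanup();
  return res == CURLE_OK ? 0 : 1;
}
```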
17 changes: 8 additions & 9 deletions engine/extensions/python-engine/python_engine.h
@@ -8,6 +8,7 @@
#include <string>
#include <unordered_map>
#include "config/model_config.h"
#include "trantor/utils/ConcurrentTaskQueue.h"
#include "cortex-common/EngineI.h"
#include "extensions/template_renderer.h"
#include "utils/file_logger.h"
@@ -36,25 +37,19 @@ static size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
std::string chunk(ptr, size * nmemb);

context->buffer += chunk;

// Process complete lines
size_t pos;
while ((pos = context->buffer.find('\n')) != std::string::npos) {
std::string line = context->buffer.substr(0, pos);
context->buffer = context->buffer.substr(pos + 1);

LOG_INFO << "line: " << line;
// Skip empty lines
if (line.empty() || line == "\r")
continue;

// Remove "data: " prefix if present
// if (line.substr(0, 6) == "data: ")
// {
// line = line.substr(6);
// }

// Skip [DONE] message
std::cout << line << std::endl;

if (line == "data: [DONE]") {
Json::Value status;
status["is_done"] = true;
Expand Down Expand Up @@ -97,6 +92,7 @@ class PythonEngine : public EngineI {
extensions::TemplateRenderer renderer_;
std::unique_ptr<trantor::FileLogger> async_file_logger_;
std::unordered_map<std::string, pid_t> processMap;
trantor::ConcurrentTaskQueue q_;

// Helper functions
CurlResponse MakePostRequest(const std::string& model,
@@ -106,7 +102,10 @@ class PythonEngine : public EngineI {
const std::string& path);
CurlResponse MakeDeleteRequest(const std::string& model,
const std::string& path);

CurlResponse MakeStreamPostRequest(
const std::string& model, const std::string& path,
const std::string& body,
const std::function<void(Json::Value&&, Json::Value&&)>& callback);
// Process manager functions
pid_t SpawnProcess(const std::string& model,
const std::vector<std::string>& command);
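`StreamWriteCallback` has to cope with curl delivering arbitrary byte ranges that need not align with SSE line boundaries, hence the persistent per-request `buffer` and the split-on-newline loop. A self-contained sketch of that reassembly logic (the chunk contents are invented for the demo):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Reassembles SSE "data: ..." lines from arbitrarily split chunks,
// mirroring the buffering loop in StreamWriteCallback.
struct SseParser {
  std::string buffer;

  void Feed(const std::string& chunk) {
    buffer += chunk;
    size_t pos;
    while ((pos = buffer.find('\n')) != std::string::npos) {
      std::string line = buffer.substr(0, pos);
      buffer.erase(0, pos + 1);
      if (line.empty() || line == "\r") continue;  // keep-alive blanks
      if (line == "data: [DONE]") {                // stream terminator
        std::cout << "[stream done]\n";
        return;
      }
      std::cout << "event line: " << line << '\n';
    }
  }
};

int main() {
  // Chunk boundaries deliberately fall mid-line.
  std::vector<std::string> chunks = {
      "data: {\"delta\"", ": \"Hel\"}\n\ndata: {\"de",
      "lta\": \"lo\"}\n", "data: [DONE]\n"};
  SseParser p;
  for (const auto& c : chunks) p.Feed(c);
}
```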