diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index d855c8f61..a38dbe70e 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -140,10 +140,10 @@ cpp::result DownloadService::GetFileSize( curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); auto headers = curl_utils::GetHeaders(url); - if (headers.has_value()) { + if (headers) { curl_slist* curl_headers = nullptr; - for (const auto& [key, value] : headers.value()) { + for (const auto& [key, value] : headers->m) { auto header = key + ": " + value; curl_headers = curl_slist_append(curl_headers, header.c_str()); } @@ -227,10 +227,10 @@ cpp::result DownloadService::Download( curl_easy_setopt(curl, CURLOPT_URL, download_item.downloadUrl.c_str()); auto headers = curl_utils::GetHeaders(download_item.downloadUrl); - if (headers.has_value()) { + if (headers) { curl_slist* curl_headers = nullptr; - for (const auto& [key, value] : headers.value()) { + for (const auto& [key, value] : headers->m) { auto header = key + ": " + value; curl_headers = curl_slist_append(curl_headers, header.c_str()); } @@ -469,7 +469,7 @@ void DownloadService::SetUpCurlHandle(CURL* handle, const DownloadItem& item, auto headers = curl_utils::GetHeaders(item.downloadUrl); if (headers) { curl_slist* curl_headers = nullptr; - for (const auto& [key, value] : headers.value()) { + for (const auto& [key, value] : headers->m) { curl_headers = curl_slist_append(curl_headers, (key + ": " + value).c_str()); } diff --git a/engine/utils/curl_utils.cc b/engine/utils/curl_utils.cc index 71f263a6a..1817948e2 100644 --- a/engine/utils/curl_utils.cc +++ b/engine/utils/curl_utils.cc @@ -9,12 +9,24 @@ namespace curl_utils { namespace { -size_t WriteCallback(void* contents, size_t size, size_t nmemb, - std::string* output) { - size_t totalSize = size * nmemb; - output->append((char*)contents, totalSize); - return totalSize; -} +class CurlResponse { + public: + static size_t WriteCallback(char* buffer, size_t size, size_t nitems, + void* userdata) { + auto* response = static_cast(userdata); + return response->Append(buffer, size * nitems); + } + + size_t Append(const char* buffer, size_t size) { + data_.append(buffer, size); + return size; + } + + const std::string& GetData() const { return data_; } + + private: + std::string data_; +}; void SetUpProxy(CURL* handle, const std::string& url) { auto config = file_manager_utils::GetCortexConfig(); @@ -59,19 +71,18 @@ void SetUpProxy(CURL* handle, const std::string& url) { } } // namespace -std::optional> GetHeaders( - const std::string& url) { +std::shared_ptr
GetHeaders(const std::string& url) { auto url_obj = url_parser::FromUrlString(url); if (url_obj.has_error()) { - return std::nullopt; + return nullptr; } if (url_obj->host == kHuggingFaceHost) { - std::unordered_map headers{}; - headers["Content-Type"] = "application/json"; + auto headers = std::make_shared
(); + headers->m["Content-Type"] = "application/json"; auto const& token = file_manager_utils::GetCortexConfig().huggingFaceToken; if (!token.empty()) { - headers["Authorization"] = "Bearer " + token; + headers->m["Authorization"] = "Bearer " + token; // for debug purpose auto min_token_size = 6; @@ -87,15 +98,15 @@ std::optional> GetHeaders( } if (url_obj->host == kGitHubHost) { - std::unordered_map headers{}; - headers["Accept"] = "application/vnd.github.v3+json"; + auto headers = std::make_shared
(); + headers->m["Accept"] = "application/vnd.github.v3+json"; // github API requires user-agent https://docs.github.com/en/rest/using-the-rest-api/getting-started-with-the-rest-api?apiVersion=2022-11-28#user-agent auto user_agent = file_manager_utils::GetCortexConfig().gitHubUserAgent; auto gh_token = file_manager_utils::GetCortexConfig().gitHubToken; - headers["User-Agent"] = + headers->m["User-Agent"] = user_agent.empty() ? kDefaultGHUserAgent : user_agent; if (!gh_token.empty()) { - headers["Authorization"] = "Bearer " + gh_token; + headers->m["Authorization"] = "Bearer " + gh_token; // for debug purpose auto min_token_size = 6; @@ -109,7 +120,7 @@ std::optional> GetHeaders( return headers; } - return std::nullopt; + return nullptr; } cpp::result SimpleGet(const std::string& url, @@ -122,8 +133,8 @@ cpp::result SimpleGet(const std::string& url, auto headers = GetHeaders(url); curl_slist* curl_headers = nullptr; - if (headers.has_value()) { - for (const auto& [key, value] : headers.value()) { + if (headers) { + for (const auto& [key, value] : headers->m) { auto header = key + ": " + value; curl_headers = curl_slist_append(curl_headers, header.c_str()); } @@ -131,12 +142,14 @@ cpp::result SimpleGet(const std::string& url, curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); } - std::string readBuffer; + auto* response = new CurlResponse(); + std::shared_ptr s(response, + std::default_delete()); SetUpProxy(curl, url); curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlResponse::WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, response); if (timeout > 0) { curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeout); } @@ -155,10 +168,10 @@ cpp::result SimpleGet(const std::string& url, if (http_code >= 400) { CTL_ERR("HTTP request failed with status code: " + std::to_string(http_code)); - return cpp::fail(readBuffer); + return cpp::fail(response->GetData()); } - return readBuffer; + return response->GetData(); } cpp::result SimpleRequest( @@ -176,13 +189,15 @@ cpp::result SimpleRequest( curl_slist_append(curl_headers, "Content-Type: application/json"); curl_headers = curl_slist_append(curl_headers, "Expect:"); - if (headers.has_value()) { - for (const auto& [key, value] : headers.value()) { + if (headers) { + for (const auto& [key, value] : headers->m) { auto header = key + ": " + value; curl_headers = curl_slist_append(curl_headers, header.c_str()); } } - std::string readBuffer; + auto* response = new CurlResponse(); + std::shared_ptr s(response, + std::default_delete()); SetUpProxy(curl, url); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); @@ -196,8 +211,8 @@ cpp::result SimpleRequest( curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "DELETE"); } curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlResponse::WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, response); curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.length()); curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); @@ -221,10 +236,10 @@ cpp::result SimpleRequest( if (http_code >= 400) { CTL_ERR("HTTP request failed with status code: " + std::to_string(http_code)); - return cpp::fail(readBuffer); + return cpp::fail(response->GetData()); } - return readBuffer; + return response->GetData(); } cpp::result ReadRemoteYaml(const std::string& url) { diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index 64b5fc339..69369def0 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -15,8 +15,11 @@ enum class RequestType { GET, PATCH, POST, DEL }; namespace curl_utils { -std::optional> GetHeaders( - const std::string& url); +struct Header { + std::unordered_map m; +}; + +std::shared_ptr
GetHeaders(const std::string& url); cpp::result SimpleGet(const std::string& url, const int timeout = -1);