Skip to content

Commit 7f8ed0d

Browse files
authored
Merge pull request #311 from janhq/j/update-engine-interface
feat: update engine interface to allow third parties to provide engines
2 parents b4aa6ab + 7332878 commit 7f8ed0d

File tree

4 files changed

+117
-4
lines changed

4 files changed

+117
-4
lines changed

CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ project(cortex.llamacpp)
33
SET(TARGET engine)
44

55
if(UNIX AND NOT APPLE)
6+
add_compile_definitions(LINUX)
67
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -pthread")
78
add_compile_options(-fPIC -pthread)
89
find_package(Threads)
@@ -53,4 +54,4 @@ target_include_directories(${TARGET} PRIVATE
5354
${CMAKE_CURRENT_SOURCE_DIR}/llama.cpp
5455
${THIRD_PARTY_PATH}/include)
5556

56-
target_compile_features(${TARGET} PUBLIC cxx_std_17)
57+
target_compile_features(${TARGET} PUBLIC cxx_std_17)

base/cortex-common/enginei.h

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#pragma once
22

3+
#include <filesystem>
34
#include <functional>
45
#include <memory>
6+
#include <vector>
57

68
#include "json/value.h"
79
#include "trantor/utils/Logger.h"
@@ -10,8 +12,37 @@
1012
// Note: only append new function to keep the compatibility.
1113
class EngineI {
1214
public:
15+
struct RegisterLibraryOption {
16+
std::vector<std::filesystem::path> paths;
17+
};
18+
19+
struct EngineLoadOption {
20+
// engine
21+
std::filesystem::path engine_path;
22+
std::filesystem::path cuda_path;
23+
bool custom_engine_path;
24+
25+
// logging
26+
std::filesystem::path log_path;
27+
int max_log_lines;
28+
trantor::Logger::LogLevel log_level;
29+
};
30+
31+
struct EngineUnloadOption {
32+
bool unload_dll;
33+
};
34+
1335
virtual ~EngineI() {}
1436

37+
/**
38+
* Being called before starting process to register dependencies search paths.
39+
*/
40+
virtual void RegisterLibraryPath(RegisterLibraryOption opts) = 0;
41+
42+
virtual void Load(EngineLoadOption opts) = 0;
43+
44+
virtual void Unload(EngineUnloadOption opts) = 0;
45+
1546
virtual void HandleChatCompletion(
1647
std::shared_ptr<Json::Value> json_body,
1748
std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;
@@ -46,4 +77,4 @@ class EngineI {
4677
virtual void SetFileLogger(int max_log_lines,
4778
const std::string& log_path) = 0;
4879
virtual void SetLogLevel(trantor::Logger::LogLevel log_level) = 0;
49-
};
80+
};

src/llama_engine.cc

+68-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ std::shared_ptr<InferenceState> CreateInferenceState(LlamaServerContext& l) {
9191
}
9292

9393
Json::Value CreateEmbeddingPayload(const std::vector<float>& embedding,
94-
9594
int index, bool is_base64) {
9695
Json::Value dataItem;
9796
dataItem["object"] = "embedding";
@@ -114,6 +113,7 @@ Json::Value CreateEmbeddingPayload(const std::vector<float>& embedding,
114113

115114
return dataItem;
116115
}
116+
117117
std::vector<int> getUTF8Bytes(const std::string& str) {
118118
std::vector<int> bytes;
119119
for (unsigned char c : str) {
@@ -271,6 +271,71 @@ std::string CreateReturnJson(const std::string& id, const std::string& model,
271271
}
272272
} // namespace
273273

274+
void LlamaEngine::RegisterLibraryPath(RegisterLibraryOption opts) {
275+
#if defined(LINUX)
276+
const char* name = "LD_LIBRARY_PATH";
277+
std::string v;
278+
if (auto g = getenv(name); g) {
279+
v += g;
280+
}
281+
LOG_DEBUG << "LD_LIBRARY_PATH before: " << v;
282+
283+
for (const auto& p : opts.paths) {
284+
v += p.string() + ":" + v;
285+
}
286+
287+
setenv(name, v.c_str(), true);
288+
LOG_DEBUG << "LD_LIBRARY_PATH after: " << getenv(name);
289+
#endif
290+
}
291+
292+
void LlamaEngine::Load(EngineLoadOption opts) {
293+
LOG_INFO << "Loading engine..";
294+
295+
LOG_DEBUG << "Use custom engine path: " << opts.custom_engine_path;
296+
LOG_DEBUG << "Engine path: " << opts.engine_path.string();
297+
298+
SetFileLogger(opts.max_log_lines, opts.log_path.string());
299+
SetLogLevel(opts.log_level);
300+
301+
#if defined(_WIN32)
302+
if (!opts.custom_engine_path) {
303+
if (auto cookie = AddDllDirectory(opts.engine_path.c_str()); cookie != 0) {
304+
LOG_INFO << "Added dll directory: " << opts.engine_path.string();
305+
cookies_.push_back(cookie);
306+
} else {
307+
LOG_WARN << "Could not add dll directory: " << opts.engine_path.string();
308+
}
309+
310+
if (auto cuda_cookie = AddDllDirectory(opts.cuda_path.c_str());
311+
cuda_cookie != 0) {
312+
LOG_INFO << "Added cuda dll directory: " << opts.cuda_path.string();
313+
cookies_.push_back(cuda_cookie);
314+
} else {
315+
LOG_WARN << "Could not add cuda dll directory: "
316+
<< opts.cuda_path.string();
317+
}
318+
}
319+
#endif
320+
LOG_INFO << "Engine loaded successfully";
321+
}
322+
323+
void LlamaEngine::Unload(EngineUnloadOption opts) {
324+
LOG_INFO << "Unloading engine..";
325+
LOG_DEBUG << "Unload dll: " << opts.unload_dll;
326+
327+
if (opts.unload_dll) {
328+
#if defined(_WIN32)
329+
for (const auto& cookie : cookies_) {
330+
if (!RemoveDllDirectory(cookie)) {
331+
LOG_WARN << "Could not remove dll directory";
332+
}
333+
}
334+
#endif
335+
}
336+
LOG_INFO << "Engine unloaded successfully";
337+
}
338+
274339
LlamaEngine::LlamaEngine(int log_option) {
275340
trantor::Logger::setLogLevel(trantor::Logger::kInfo);
276341
if (log_option == kFileLoggerOption) {
@@ -303,6 +368,8 @@ LlamaEngine::~LlamaEngine() {
303368
}
304369
server_map_.clear();
305370
async_file_logger_.reset();
371+
372+
LOG_INFO << "LlamaEngine destructed successfully";
306373
}
307374

308375
void LlamaEngine::HandleChatCompletion(

src/llama_engine.h

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
23
#include <trantor/utils/AsyncFileLogger.h>
34
#include "chat_completion_request.h"
45
#include "cortex-common/enginei.h"
@@ -10,9 +11,18 @@
1011

1112
class LlamaEngine : public EngineI {
1213
public:
14+
constexpr static auto kEngineName = "cortex.llamacpp";
15+
1316
LlamaEngine(int log_option = 0);
1417
~LlamaEngine() final;
18+
1519
// #### Interface ####
20+
void RegisterLibraryPath(RegisterLibraryOption opts) final;
21+
22+
void Load(EngineLoadOption opts) final;
23+
24+
void Unload(EngineUnloadOption opts) final;
25+
1626
void HandleChatCompletion(
1727
std::shared_ptr<Json::Value> jsonBody,
1828
std::function<void(Json::Value&&, Json::Value&&)>&& callback) final;
@@ -74,4 +84,8 @@ class LlamaEngine : public EngineI {
7484

7585
bool print_version_ = true;
7686
std::unique_ptr<trantor::FileLogger> async_file_logger_;
77-
};
87+
88+
#if defined(_WIN32)
89+
std::vector<DLL_DIRECTORY_COOKIE> cookies_;
90+
#endif
91+
};

0 commit comments

Comments
 (0)