chore: dllama-api tiny clean up. (#66)
b4rtaz authored May 26, 2024
1 parent ef1e312 commit 06cd0eb
Showing 3 changed files with 182 additions and 175 deletions.
4 changes: 2 additions & 2 deletions examples/chat-api-client.js
@@ -5,8 +5,8 @@
// 1. Start the server; how to do this is described in the `src/apps/dllama-api/README.md` file.
// 2. Run this script: `node examples/chat-api-client.js`

const HOST = '127.0.0.1';
const PORT = 9990;
const HOST = process.env.HOST ? process.env.HOST : '127.0.0.1';
const PORT = process.env.PORT ? Number(process.env.PORT) : 9990;

async function chat(messages, maxTokens) {
const response = await fetch(`http://${HOST}:${PORT}/v1/chat/completions`, {
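With this change the client reads its target from the environment, falling back to `127.0.0.1:9990` when `HOST` and `PORT` are unset, so it can be pointed at a remote server without editing the script, e.g. `HOST=192.168.0.10 PORT=9990 node examples/chat-api-client.js` (the address shown here is illustrative).
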
214 changes: 41 additions & 173 deletions src/apps/dllama-api/dllama-api.cpp
@@ -10,6 +10,7 @@
#include <unistd.h>
#include <vector>

#include "types.hpp"
#include "../../utils.hpp"
#include "../../socket.hpp"
#include "../../transformer.hpp"
@@ -126,154 +127,11 @@ class Router {
}
private:
static void notFoundHandler(Socket& client_socket) {
std::string header = "HTTP/1.1 404 Not Found\r\n";
client_socket.write(header.c_str(), header.length());
const char* data = "HTTP/1.1 404 Not Found\r\n";
client_socket.write(data, strlen(data));
}
};

struct ChatMessageDelta {
std::string role;
std::string content;

ChatMessageDelta() : role(""), content("") {}
ChatMessageDelta(const std::string& role_, const std::string& content_) : role(role_), content(content_) {}
};

// Define to_json for Delta struct
void to_json(json& j, const ChatMessageDelta& msg) {
j = json{{"role", msg.role}, {"content", msg.content}};
}

struct ChatMessage {
std::string role;
std::string content;

ChatMessage() : role(""), content("") {}
ChatMessage(const std::string& role_, const std::string& content_) : role(role_), content(content_) {}
};

// Define to_json for ChatMessage struct
void to_json(json& j, const ChatMessage& msg) {
j = json{{"role", msg.role}, {"content", msg.content}};
}

struct ChunkChoice {
int index;
ChatMessageDelta delta;
std::string finish_reason;

ChunkChoice() : index(0) {}
};

// Define to_json for ChunkChoice struct
void to_json(json& j, const ChunkChoice& choice) {
j = json{{"index", choice.index}, {"delta", choice.delta}, {"finish_reason", choice.finish_reason}};
}

struct Choice {
int index;
ChatMessage message;
std::string finish_reason;

Choice() : finish_reason("") {}
Choice(ChatMessage &message_) : message(message_), finish_reason("") {}
Choice(const std::string &reason_) : finish_reason(reason_) {}
};

// Define to_json for Choice struct
void to_json(json& j, const Choice& choice) {
j = json{{"index", choice.index}, {"message", choice.message}, {"finish_reason", choice.finish_reason}};
}

std::string createJsonResponse(std::string json) {
std::ostringstream oss;
oss << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: application/json; charset=utf-8\r\n"
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
return oss.str();
}

struct ChatCompletionChunk {
std::string id;
std::string object;
long long created;
std::string model;
std::vector<ChunkChoice> choices;

ChatCompletionChunk(ChunkChoice &choice_)
: id("chatcmpl-test"), object("chat.completion"), model("dl") {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
};

// Define to_json for ChatCompletionChunk struct
void to_json(json& j, const ChatCompletionChunk& completion) {
j = json{{"id", completion.id},
{"object", completion.object},
{"created", completion.created},
{"model", completion.model},
{"choices", completion.choices}};
}

// Struct to represent the usage object
struct ChatUsage {
int prompt_tokens;
int completion_tokens;
int total_tokens;
ChatUsage() : prompt_tokens(0), completion_tokens(0), total_tokens(0) {}
ChatUsage(int pt, int ct, int tt) : prompt_tokens(pt), completion_tokens(ct), total_tokens(tt) {}
};

// Struct to represent the chat completion object
struct ChatCompletion {
std::string id;
std::string object;
long long created; // Unix timestamp
std::string model;
std::vector<Choice> choices;
ChatUsage usage;

ChatCompletion(Choice &choice_)
: id("chatcmpl-test"), object("chat.completion"), model("dl") {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
};

// Define to_json for ChatCompletion struct
void to_json(json& j, const ChatCompletion& completion) {
j = json{{"id", completion.id},
{"object", completion.object},
{"created", completion.created},
{"model", completion.model},
{"choices", completion.choices}};
}

struct InferenceParams {
std::string prompt;
int max_tokens;
float temperature;
float top_p;
std::vector<std::string> stop;
bool stream;
unsigned long long seed;
};

std::vector<ChatMessage> parseChatMessages(json &json){
std::vector<ChatMessage> messages;
messages.reserve(json.size());

for (const auto& item : json) {
messages.emplace_back(
item["role"].template get<std::string>(),
item["content"].template get<std::string>()
);
}

return messages;
}
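
For context, the to_json overloads above plug into the json library's ADL-based serialization (the to_json convention used by nlohmann::json), so any of these structs can be converted to a json value by plain assignment and then dumped to a string. A minimal sketch, assuming the structs keep the same shape wherever they now live (e.g. in types.hpp):

ChatMessage msg("user", "Hello");
json j = msg;                    // found via to_json(json&, const ChatMessage&)
std::string body = j.dump();     // a JSON object with "role" and "content" fields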

/*
Generally speaking, the tokenizer.config.json would contain the chat template for the model
and depending on the model used you could set the chat template to follow
@@ -282,26 +140,46 @@ for this code draft I am assuming the use of llama 3 instruct
*/
std::string buildChatPrompt(Tokenizer *tokenizer, const std::vector<ChatMessage> &messages){
std::ostringstream oss;

for (const auto& message : messages) {
oss << "<|start_header_id|>" << message.role << "<|end_header_id|>\n\n" << message.content << "<|eot_id|>";
}

oss << "<|start_header_id|>assistant<|end_header_id|>\n\n";

return oss.str();
}
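
As a rough illustration of the template above (using the llama 3 instruct tags this draft assumes), a single user message "Hello" produces a prompt along these lines, with the trailing assistant header left open for the model to continue:

<|start_header_id|>user<|end_header_id|>

Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>
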

void writeChunk(Socket& socket, const std::string data, const bool stop) {
std::ostringstream formattedChunk;
formattedChunk << std::hex << data.size() << "\r\n" << data << "\r\n";
if (stop) {
formattedChunk << "0000\r\n\r\n";
}
socket.write(formattedChunk.str().c_str(), formattedChunk.str().size());
void writeJsonResponse(Socket& socket, std::string json) {
std::ostringstream buffer;
buffer << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: application/json; charset=utf-8\r\n"
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
std::string data = buffer.str();
socket.write(data.c_str(), data.size());
}

void writeStreamStartChunk(Socket& socket) {
std::ostringstream buffer;
buffer << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: text/event-stream; charset=utf-8\r\n"
<< "Connection: close\r\n"
<< "Transfer-Encoding: chunked\r\n\r\n";
std::string data = buffer.str();
socket.write(data.c_str(), data.size());
}

void writeStreamChunk(Socket& socket, const std::string data) {
std::ostringstream buffer;
buffer << std::hex << data.size() << "\r\n" << data << "\r\n";
std::string d = buffer.str();
socket.write(d.c_str(), d.size());
}

void writeStreamEndChunk(Socket& socket) {
const char* endChunk = "0000\r\n\r\n";
socket.write(endChunk, strlen(endChunk));
}
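
These helpers hand-roll HTTP/1.1 chunked transfer encoding: each chunk is prefixed with its payload length in hexadecimal followed by CRLF, and the stream is closed with a zero-length chunk. A minimal sketch of the bytes that reach the wire for the final SSE event (the payload is the one writeChatCompletionChunk sends below):

writeStreamChunk(socket, "data: [DONE]");  // wire: "c\r\n" "data: [DONE]" "\r\n"   (12 bytes == 0xc)
writeStreamEndChunk(socket);               // wire: "0000\r\n\r\n"                  (zero-length chunk ends the stream)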

void writeChatCompletionChunk(Socket &client_socket, const std::string &delta, const bool stop){
void writeChatCompletionChunk(Socket &socket, const std::string &delta, const bool stop){
ChunkChoice choice;
if (stop) {
choice.finish_reason = "stop";
@@ -312,15 +190,15 @@ void writeChatCompletionChunk(Socket &client_socket, const std::string &delta, c

std::ostringstream buffer;
buffer << "data: " << ((json)chunk).dump() << "\r\n\r\n";
writeChunk(client_socket, buffer.str(), false);
writeStreamChunk(socket, buffer.str());

if (stop) {
writeChunk(client_socket, "data: [DONE]", true);
writeStreamChunk(socket, "data: [DONE]");
writeStreamEndChunk(socket);
}
}

void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
// Set inference arguments
InferenceParams params;
params.temperature = args->temperature;
params.top_p = args->topp;
@@ -360,13 +238,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
generated.get_allocator().allocate(params.max_tokens);

if (params.stream) {
std::ostringstream oss;
oss << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: text/event-stream; charset=utf-8\r\n"
<< "Connection: close\r\n"
<< "Transfer-Encoding: chunked\r\n\r\n";

socket.write(oss.str().c_str(), oss.str().length());
writeStreamStartChunk(socket);
}

int promptLength = params.prompt.length();
@@ -386,8 +258,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i

if (pos < nPromptTokens - 1) {
token = promptTokens[pos + 1];
}
else {
} else {
int prevToken = token;
token = sampler->sample(logits);

@@ -420,7 +291,6 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
fflush(stdout);

generated.push_back(string);

if (params.stream) {
writeChatCompletionChunk(socket, string, false);
}
@@ -434,8 +304,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
completion.usage = ChatUsage(nPromptTokens, generated.size(), nPromptTokens + generated.size());

std::string chatJson = ((json)completion).dump();
std::string response = createJsonResponse(chatJson);
socket.write(response.c_str(), response.length());
writeJsonResponse(socket, chatJson);
} else {
writeChatCompletionChunk(socket, "", true);
}
@@ -444,12 +313,11 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
}

void handleModelsRequest(Socket& client_socket, HttpRequest& request) {
std::string response = createJsonResponse(
writeJsonResponse(client_socket,
"{ \"object\": \"list\","
"\"data\": ["
"{ \"id\": \"dl\", \"object\": \"model\", \"created\": 0, \"owned_by\": \"user\" }"
"] }");
client_socket.write(response.c_str(), response.length());
}

void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
@@ -492,7 +360,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,

int main(int argc, char *argv[]) {
initQuants();

AppArgs args = AppArgs::parse(argc, argv, false);
App::run(&args, server);
return EXIT_SUCCESS;
