chore: dllama-api tiny clean up. #66

Merged: 1 commit, May 26, 2024
4 changes: 2 additions & 2 deletions examples/chat-api-client.js
@@ -5,8 +5,8 @@
// 1. Start the server, how to do it is described in the `src/apps/dllama-api/README.md` file.
// 2. Run this script: `node examples/chat-api-client.js`

const HOST = '127.0.0.1';
const PORT = 9990;
const HOST = process.env.HOST ? process.env.HOST : '127.0.0.1';
const PORT = process.env.PORT ? Number(process.env.PORT) : 9990;

async function chat(messages, maxTokens) {
const response = await fetch(`http://${HOST}:${PORT}/v1/chat/completions`, {
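With this change the client no longer hard-codes its target: the host and port can be overridden from the environment when running the script mentioned in the comments above, for example `HOST=192.168.0.10 PORT=9990 node examples/chat-api-client.js` on a POSIX shell (the address and port shown here are placeholders only).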
214 changes: 41 additions & 173 deletions src/apps/dllama-api/dllama-api.cpp
@@ -10,6 +10,7 @@
#include <unistd.h>
#include <vector>

#include "types.hpp"
#include "../../utils.hpp"
#include "../../socket.hpp"
#include "../../transformer.hpp"
@@ -126,154 +127,11 @@ class Router {
}
private:
static void notFoundHandler(Socket& client_socket) {
std::string header = "HTTP/1.1 404 Not Found\r\n";
client_socket.write(header.c_str(), header.length());
const char* data = "HTTP/1.1 404 Not Found\r\n";
client_socket.write(data, strlen(data));
}
};

struct ChatMessageDelta {
std::string role;
std::string content;

ChatMessageDelta() : role(""), content("") {}
ChatMessageDelta(const std::string& role_, const std::string& content_) : role(role_), content(content_) {}
};

// Define to_json for Delta struct
void to_json(json& j, const ChatMessageDelta& msg) {
j = json{{"role", msg.role}, {"content", msg.content}};
}

struct ChatMessage {
std::string role;
std::string content;

ChatMessage() : role(""), content("") {}
ChatMessage(const std::string& role_, const std::string& content_) : role(role_), content(content_) {}
};

// Define to_json for ChatMessage struct
void to_json(json& j, const ChatMessage& msg) {
j = json{{"role", msg.role}, {"content", msg.content}};
}

struct ChunkChoice {
int index;
ChatMessageDelta delta;
std::string finish_reason;

ChunkChoice() : index(0) {}
};

// Define to_json for ChunkChoice struct
void to_json(json& j, const ChunkChoice& choice) {
j = json{{"index", choice.index}, {"delta", choice.delta}, {"finish_reason", choice.finish_reason}};
}

struct Choice {
int index;
ChatMessage message;
std::string finish_reason;

Choice() : finish_reason("") {}
Choice(ChatMessage &message_) : message(message_), finish_reason("") {}
Choice(const std::string &reason_) : finish_reason(reason_) {}
};

// Define to_json for Choice struct
void to_json(json& j, const Choice& choice) {
j = json{{"index", choice.index}, {"message", choice.message}, {"finish_reason", choice.finish_reason}};
}

std::string createJsonResponse(std::string json) {
std::ostringstream oss;
oss << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: application/json; charset=utf-8\r\n"
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
return oss.str();
}

struct ChatCompletionChunk {
std::string id;
std::string object;
long long created;
std::string model;
std::vector<ChunkChoice> choices;

ChatCompletionChunk(ChunkChoice &choice_)
: id("chatcmpl-test"), object("chat.completion"), model("dl") {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
};

// Define to_json for ChatCompletionChunk struct
void to_json(json& j, const ChatCompletionChunk& completion) {
j = json{{"id", completion.id},
{"object", completion.object},
{"created", completion.created},
{"model", completion.model},
{"choices", completion.choices}};
}

// Struct to represent the usage object
struct ChatUsage {
int prompt_tokens;
int completion_tokens;
int total_tokens;
ChatUsage() : prompt_tokens(0), completion_tokens(0), total_tokens(0) {}
ChatUsage(int pt, int ct, int tt) : prompt_tokens(pt), completion_tokens(ct), total_tokens(tt) {}
};

// Struct to represent the chat completion object
struct ChatCompletion {
std::string id;
std::string object;
long long created; // Unix timestamp
std::string model;
std::vector<Choice> choices;
ChatUsage usage;

ChatCompletion(Choice &choice_)
: id("chatcmpl-test"), object("chat.completion"), model("dl") {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
};

// Define to_json for ChatCompletion struct
void to_json(json& j, const ChatCompletion& completion) {
j = json{{"id", completion.id},
{"object", completion.object},
{"created", completion.created},
{"model", completion.model},
{"choices", completion.choices}};
}

struct InferenceParams {
std::string prompt;
int max_tokens;
float temperature;
float top_p;
std::vector<std::string> stop;
bool stream;
unsigned long long seed;
};

std::vector<ChatMessage> parseChatMessages(json &json){
std::vector<ChatMessage> messages;
messages.reserve(json.size());

for (const auto& item : json) {
messages.emplace_back(
item["role"].template get<std::string>(),
item["content"].template get<std::string>()
);
}

return messages;
}

/*
Generally speaking, the tokenizer.config.json would contain the chat template for the model
and depending on the model used you could set the chat template to follow
@@ -282,26 +140,46 @@ for this code draft I am assuming the use of llama 3 instruct
*/
std::string buildChatPrompt(Tokenizer *tokenizer, const std::vector<ChatMessage> &messages){
std::ostringstream oss;

for (const auto& message : messages) {
oss << "<|start_header_id|>" << message.role << "<|end_header_id|>\n\n" << message.content << "<|eot_id|>";
}

oss << "<|start_header_id|>assistant<|end_header_id|>\n\n";

return oss.str();
}
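For reference, a minimal usage sketch (not part of the diff; it assumes the ChatMessage type now defined in types.hpp, the buildChatPrompt function above, and an already constructed Tokenizer* named tokenizer) would produce a prompt of the following shape:

// Sketch only — assumes ChatMessage{role, content} and buildChatPrompt as shown above.
std::vector<ChatMessage> messages = {
    ChatMessage("system", "You are a helpful assistant."),
    ChatMessage("user", "Hello!")
};
std::string prompt = buildChatPrompt(tokenizer, messages);
// prompt is the concatenation (wrapped here only for readability; the string has no
// extra newlines between messages):
//   "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>"
//   "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
//   "<|start_header_id|>assistant<|end_header_id|>\n\n"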

void writeChunk(Socket& socket, const std::string data, const bool stop) {
std::ostringstream formattedChunk;
formattedChunk << std::hex << data.size() << "\r\n" << data << "\r\n";
if (stop) {
formattedChunk << "0000\r\n\r\n";
}
socket.write(formattedChunk.str().c_str(), formattedChunk.str().size());
void writeJsonResponse(Socket& socket, std::string json) {
std::ostringstream buffer;
buffer << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: application/json; charset=utf-8\r\n"
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
std::string data = buffer.str();
socket.write(data.c_str(), data.size());
}

void writeStreamStartChunk(Socket& socket) {
std::ostringstream buffer;
buffer << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: text/event-stream; charset=utf-8\r\n"
<< "Connection: close\r\n"
<< "Transfer-Encoding: chunked\r\n\r\n";
std::string data = buffer.str();
socket.write(data.c_str(), data.size());
}

void writeStreamChunk(Socket& socket, const std::string data) {
std::ostringstream buffer;
buffer << std::hex << data.size() << "\r\n" << data << "\r\n";
std::string d = buffer.str();
socket.write(d.c_str(), d.size());
}

void writeStreamEndChunk(Socket& socket) {
const char* endChunk = "0000\r\n\r\n";
socket.write(endChunk, strlen(endChunk));
}
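As a rough composition sketch (not part of the diff; `client` is a hypothetical connected Socket from ../../socket.hpp), the three streaming helpers introduced here are intended to be used in this order:

// Sketch only — shows the intended call order of the helpers above.
void streamExample(Socket& client) {
    writeStreamStartChunk(client);                                      // status line, SSE headers, Transfer-Encoding: chunked
    writeStreamChunk(client, "data: {\"content\":\"Hello\"}\r\n\r\n");  // one chunk per generated piece
    writeStreamChunk(client, "data: [DONE]");                           // end-of-stream sentinel
    writeStreamEndChunk(client);                                        // zero-length chunk ("0000\r\n\r\n") closes the body
}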

void writeChatCompletionChunk(Socket &client_socket, const std::string &delta, const bool stop){
void writeChatCompletionChunk(Socket &socket, const std::string &delta, const bool stop){
ChunkChoice choice;
if (stop) {
choice.finish_reason = "stop";
Expand All @@ -312,15 +190,15 @@ void writeChatCompletionChunk(Socket &client_socket, const std::string &delta, c

std::ostringstream buffer;
buffer << "data: " << ((json)chunk).dump() << "\r\n\r\n";
writeChunk(client_socket, buffer.str(), false);
writeStreamChunk(socket, buffer.str());

if (stop) {
writeChunk(client_socket, "data: [DONE]", true);
writeStreamChunk(socket, "data: [DONE]");
writeStreamEndChunk(socket);
}
}

void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
// Set inference arguments
InferenceParams params;
params.temperature = args->temperature;
params.top_p = args->topp;
@@ -360,13 +238,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
generated.get_allocator().allocate(params.max_tokens);

if (params.stream) {
std::ostringstream oss;
oss << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: text/event-stream; charset=utf-8\r\n"
<< "Connection: close\r\n"
<< "Transfer-Encoding: chunked\r\n\r\n";

socket.write(oss.str().c_str(), oss.str().length());
writeStreamStartChunk(socket);
}

int promptLength = params.prompt.length();
@@ -386,8 +258,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i

if (pos < nPromptTokens - 1) {
token = promptTokens[pos + 1];
}
else {
} else {
int prevToken = token;
token = sampler->sample(logits);

@@ -420,7 +291,6 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
fflush(stdout);

generated.push_back(string);

if (params.stream) {
writeChatCompletionChunk(socket, string, false);
}
@@ -434,8 +304,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
completion.usage = ChatUsage(nPromptTokens, generated.size(), nPromptTokens + generated.size());

std::string chatJson = ((json)completion).dump();
std::string response = createJsonResponse(chatJson);
socket.write(response.c_str(), response.length());
writeJsonResponse(socket, chatJson);
} else {
writeChatCompletionChunk(socket, "", true);
}
@@ -444,12 +313,11 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
}

void handleModelsRequest(Socket& client_socket, HttpRequest& request) {
std::string response = createJsonResponse(
writeJsonResponse(client_socket,
"{ \"object\": \"list\","
"\"data\": ["
"{ \"id\": \"dl\", \"object\": \"model\", \"created\": 0, \"owned_by\": \"user\" }"
"] }");
client_socket.write(response.c_str(), response.length());
}
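For readability only, the JSON literal returned by handleModelsRequest above expands to:

{ "object": "list",
  "data": [
    { "id": "dl", "object": "model", "created": 0, "owned_by": "user" }
  ] }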

void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
@@ -492,7 +360,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,

int main(int argc, char *argv[]) {
initQuants();

AppArgs args = AppArgs::parse(argc, argv, false);
App::run(&args, server);
return EXIT_SUCCESS;