Add additional chat templates to dllama-api #73

Closed
3 changes: 3 additions & 0 deletions src/app.cpp
@@ -31,6 +31,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
     args.topp = 0.9f;
     args.steps = 0;
     args.seed = (unsigned long long)time(NULL);
+    args.chat_template = "llama3";
 
     int i = 1;
     if (hasMode && argc > 1) {
@@ -84,6 +85,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
             args.topp = atof(argv[i + 1]);
         } else if (strcmp(argv[i], "--seed") == 0) {
             args.seed = atoll(argv[i + 1]);
+        } else if (strcmp(argv[i], "--chat-template") == 0) {
+            args.chat_template = std::string(argv[i + 1]);
         } else {
             printf("Unknown option %s\n", argv[i]);
             exit(EXIT_FAILURE);
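Usage note: with this flag in place, the template name is appended to the server's usual arguments, e.g. (binary name and the model/tokenizer flags here are illustrative; see the project's README for the exact invocation):

./dllama-api --model dllama_model.m --tokenizer dllama_tokenizer.t --chat-template chatml

Omitting --chat-template keeps the llama3 default set during initialization above.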
3 changes: 3 additions & 0 deletions src/app.hpp
@@ -36,6 +36,9 @@ class AppArgs {
     // worker
     int port;
 
+    // API specific
+    std::string chat_template;
+
     static AppArgs parse(int argc, char** argv, bool hasMode);
 };
 
63 changes: 52 additions & 11 deletions src/apps/dllama-api/dllama-api.cpp
@@ -164,19 +164,60 @@ class Router {
     }
 };
 
-/*
-Generally speaking, the tokenizer.config.json would contain the chat template for the model
-and depending on the model used you could set the chat template to follow
-could possibly just for simplicity set this in ServerArgs with --chat-template
-for this code draft I am assuming the use of llama 3 instruct
-*/
-std::string buildChatPrompt(Tokenizer *tokenizer, const std::vector<ChatMessage> &messages){
+std::string buildChatPrompt(Tokenizer *tokenizer, AppArgs* args, const std::vector<ChatMessage> &messages){
     std::ostringstream oss;
-    for (const auto& message : messages) {
-        oss << "<|start_header_id|>" << message.role << "<|end_header_id|>\n\n" << message.content << "<|eot_id|>";
+    const std::string& chat_template = args->chat_template;
+
+    if (chat_template == "llama3") {
+        for (const auto& message : messages) {
+            oss << "<|start_header_id|>" << message.role << "<|end_header_id|>\n\n" << message.content << "<|eot_id|>";
+        }
+        oss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
+    }
+    else if (chat_template == "llama2") {
+        // Llama 2 expects the system prompt inside the first [INST] block, directly
+        // followed by the first user message; assistant replies close each turn.
+        bool instOpen = false;
+        for (size_t i = 0; i < messages.size(); ++i) {
+            const auto& message = messages[i];
+            if (i == 0 && message.role == "system") {
+                oss << "[INST] <<SYS>>\n" << message.content << "\n<</SYS>>\n\n";
+                instOpen = true;
+            }
+            else if (message.role == "user") {
+                if (!instOpen) oss << "[INST] ";
+                oss << message.content << " [/INST]";
+                instOpen = false;
+            }
+            else if (message.role == "assistant") {
+                oss << " " << message.content << " </s><s>";
+            }
+        }
+    }
+    else if (chat_template == "chatml") {
+        for (const auto& message : messages) {
+            oss << "<|im_start|>" << message.role << "\n" << message.content << "<|im_end|>\n";
+        }
+        oss << "<|im_start|>assistant\n";
+    }
+    else if (chat_template == "openchat") {
+        for (size_t i = 0; i < messages.size(); ++i) {
+            const auto& message = messages[i];
+            if (i == 0 && message.role == "system") {
+                oss << message.content << "<|end_of_turn|>";
+            }
+            else if (message.role == "user") {
+                oss << "GPT4 Correct User: " << message.content << "<|end_of_turn|>";
+            }
+            else if (message.role == "assistant") {
+                oss << "GPT4 Correct Assistant: " << message.content << "<|end_of_turn|>";
+            }
+        }
+        // Cue the model to answer as the assistant.
+        oss << "GPT4 Correct Assistant:";
+    }
+    else if (chat_template == "openchat3") {
+        // std::toupper requires <cctype>; the cast avoids UB on negative char values.
+        auto capitalize = [](const std::string& input) -> std::string {
+            if (input.empty()) return input;
+            return std::string(1, (char)std::toupper((unsigned char)input[0])) + input.substr(1);
+        };
+        for (const auto& message : messages) {
+            oss << "<|start_header_id|>GPT4 Correct " << capitalize(message.role) << "<|end_header_id|>\n\n" << message.content << "<|eot_id|>";
+        }
+        oss << "<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n";
     }
 
-    oss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
     return oss.str();
 }
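For reference, a minimal sketch of what buildChatPrompt yields for the chatml template. This is illustrative only: the ChatMessage brace-initialization (role, then content) and the tokenizer variable are assumptions, since neither definition appears in this diff.

// Illustrative usage; ChatMessage is assumed to aggregate {role, content}.
std::vector<ChatMessage> messages = {
    {"system", "You are a helpful assistant."},
    {"user", "Hello!"}
};
AppArgs args;
args.chat_template = "chatml";
std::string prompt = buildChatPrompt(tokenizer, &args, messages);
// prompt now contains:
//   "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
//   "<|im_start|>user\nHello!<|im_end|>\n"
//   "<|im_start|>assistant\n"

Note that a name matching none of the five templates falls through every branch and returns an empty prompt, so a typo in --chat-template fails silently; surfacing an error there may be worth considering.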

@@ -205,7 +246,7 @@ void handleCompletionsRequest(HttpRequest& request, Inference* inference, Tokeni
     params.top_p = args->topp;
     params.seed = args->seed;
     params.stream = false;
-    params.prompt = buildChatPrompt(tokenizer, parseChatMessages(request.parsedJson["messages"]));
+    params.prompt = buildChatPrompt(tokenizer, args, parseChatMessages(request.parsedJson["messages"]));
     params.max_tokens = spec->seqLen - params.prompt.size();
 
     if (request.parsedJson.contains("stream")) {