Skip to content

Commit

Permalink
eos detector.
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz committed May 31, 2024
1 parent 5646e14 commit 72f8c3b
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 48 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ jobs:
make dllama-api
make funcs-test
make quants-test
make tokenizer-test
make transformer-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
run: ./funcs-test
- name: quants-test
run: ./quants-test
- name: tokenizer-test
run: ./tokenizer-test
- name: transformer-test
run: ./transformer-test
- name: llama2-tasks-test
Expand All @@ -60,13 +63,16 @@ jobs:
make dllama-api
make funcs-test
make quants-test
make tokenizer-test
make transformer-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
run: ./funcs-test
- name: quants-test
run: ./quants-test
- name: tokenizer-test
run: ./tokenizer-test
- name: transformer-test
run: ./transformer-test
- name: llama2-tasks-test
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ funcs-test: src/funcs-test.cpp funcs utils quants
$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o $(LIBS)
quants-test: src/quants.cpp utils quants
$(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o $(LIBS)
tokenizer-test: src/tokenizer-test.cpp tokenizer funcs utils quants
$(CXX) $(CXXFLAGS) src/tokenizer-test.cpp -o tokenizer-test tokenizer.o funcs.o utils.o quants.o $(LIBS)
transformer-test: src/transformer-test.cpp funcs utils quants transformer socket
$(CXX) $(CXXFLAGS) src/transformer-test.cpp -o transformer-test funcs.o utils.o quants.o transformer.o socket.o $(LIBS)
llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks tokenizer
Expand Down
2 changes: 1 addition & 1 deletion converter/convert-tokenizer-hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def printUsage():
print()
print('⭐ To create the tokenizer file you need to manually specify chat template values. Enter \\n for new line.')
templateChat = {}
templateKeys = ['chat_message_start', 'chat_role_start', 'chat_role_end', 'chat_message_end', 'chat_generation_prompt']
templateKeys = ['chat_message_start', 'chat_role_start', 'chat_role_end', 'chat_message_end', 'chat_generation_prompt', 'chat_extra_stop']
for key in templateKeys:
value = input(f'⏩ Enter value for chat template key "{key}":\n')
templateChat[key] = value.replace('\\n', '\n')
Expand Down
1 change: 1 addition & 0 deletions converter/convert-tokenizer-llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
'chat_role_end': '<|end_header_id|>\n\n',
'chat_message_end': '<|eot_id|>',
'chat_generation_prompt': '<|start_header_id|>assistant<|end_header_id|>\n\n',
'chat_extra_stop': ''
}

if __name__ == '__main__':
Expand Down
77 changes: 32 additions & 45 deletions src/apps/dllama-api/dllama-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,30 +238,21 @@ class ApiServer {
Sampler* sampler;
AppArgs* args;
TransformerSpec* spec;
EosDetector* eosDetector;
NaiveCache naiveCache;

int eosId;
std::string eos;

public:
ApiServer(
Inference* inference,
Tokenizer* tokenizer,
Sampler* sampler,
AppArgs* args,
TransformerSpec* spec) {
ApiServer( Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, EosDetector* eosDetector) {
this->inference = inference;
this->tokenizer = tokenizer;
this->sampler = sampler;
this->args = args;
this->spec = spec;
eosId = (tokenizer->chatEosId >= 0) ? tokenizer->chatEosId : tokenizer->eosId;
assert(eosId >= 0);
eos = tokenizer->vocab[eosId];
this->eosDetector = eosDetector;
}

std::string buildChatPrompt(std::vector<ChatMessage> messages) {
assert(tokenizer->nChatTemplates == 5);
assert(tokenizer->nChatTemplates == 6);

std::ostringstream buffer;
for (const auto& message : messages) {
Expand Down Expand Up @@ -308,7 +299,6 @@ class ApiServer {
request.writeStreamStartChunk();
}

std::string delta;
std::string buffer;
size_t nStops = params.stop.size();

Expand All @@ -323,45 +313,27 @@ class ApiServer {
int prevToken = token;
token = sampler->sample(logits);

if (token == eosId) {
printf("🔴");
break;
}

char* piece = tokenizer->decode(prevToken, token);
bool isSafe = isSafePiece(piece);

int eosType = eosDetector->append(token, isSafe ? piece : "");

if (isSafePiece(piece)) {
printf("%s", piece);
fflush(stdout);
delta += piece;
}

bool maybeEos = false;
size_t deltaSize = delta.size();
if (nStops > 0 && deltaSize > 0) {
bool isEos = false;
size_t eosSize = eos.size();
if (eos.compare(0, deltaSize, delta) == 0) {
if (eosSize <= deltaSize) {
isEos = true;
break;
} else {
maybeEos = true;
break;
}
}
if (isEos) {
printf("");
break;
if (eosType == NOT_EOS || eosType == EOS) {
char* delta = eosDetector->getDelta();
if (delta != NULL) {
std::string deltaStr(delta);
if (params.stream)
writeChatCompletionChunk(request, deltaStr, false);
buffer += deltaStr;
}
eosDetector->clear();
}

if (!maybeEos) {
if (params.stream)
writeChatCompletionChunk(request, delta, false);
buffer += delta;
delta.clear();
}
if (eosType == EOS) break;
}
}

Expand Down Expand Up @@ -432,9 +404,24 @@ void handleModelsRequest(HttpRequest& request) {
}

void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
if (tokenizer->chatEosId < 0) {
printf("⛔ 0.8.0 version introduced a new format of the tokenizer that includes chatEosId. Please update your tokenizer.\n");
exit(EXIT_FAILURE);
}

SocketServer* server = new SocketServer(args->port);

const bool hasExtraStop = tokenizer->chatTemplate[5][0] != '\0';
const int nStops = hasExtraStop ? 2 : 1;
char* stops[nStops];
stops[0] = tokenizer->vocab[tokenizer->chatEosId];
if (hasExtraStop)
stops[1] = tokenizer->chatTemplate[5];

EosDetector eosDetector(tokenizer->chatEosId, nStops, (const char**)stops, 1, 1);
ApiServer api(inference, tokenizer, sampler, args, spec, &eosDetector);

printf("Server URL: http://127.0.0.1:%d/v1/\n", args->port);
ApiServer api(inference, tokenizer, sampler, args, spec);

std::vector<Route> routes = {
{
Expand Down
132 changes: 132 additions & 0 deletions src/tokenizer-test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include "tokenizer.hpp"

#define ASSERT_EOS_TYPE(type, expected) \
if (type != expected) { \
printf("Expected %d, got %d (line: %d)\n", expected, type, __LINE__); \
exit(1); \
}

#define EOS_ID 10000

void testEosDetectorWithPadding() {
const char* stops[2] = { "<eos>", "<stop>" };
EosDetector detector(EOS_ID, 2, stops, 1, 1);

// "<eos>"
{
ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(2, "eo"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(3, "s>"), EOS);
assert(detector.getDelta() == NULL);
}

// "<stop> "
detector.clear();
{
ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(2, "stop"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(3, "> "), EOS);
assert(detector.getDelta() == NULL);
}

// " "
detector.clear();
{
ASSERT_EOS_TYPE(detector.append(1, " "), NOT_EOS);

char* delta = detector.getDelta();
assert(delta != NULL);
assert(strcmp(delta, " ") == 0);
}

// "!<eos> "
detector.clear();
{
ASSERT_EOS_TYPE(detector.append(1, "!<"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(2, "eos"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(3, "> "), EOS);

char* delta = detector.getDelta();
assert(delta != NULL);
assert(strcmp(delta, "!") == 0);
}

// "!<eos> "
detector.clear();
{
ASSERT_EOS_TYPE(detector.append(1, "<eo"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(2, "s>XY"), NOT_EOS);

char* delta = detector.getDelta();
assert(delta != NULL);
assert(strcmp(delta, "<eos>XY") == 0);
}

// "<eo" + EOS
detector.clear();
{
ASSERT_EOS_TYPE(detector.append(1, "<eo"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(EOS_ID, "<eos>"), EOS);

char* delta = detector.getDelta();
assert(delta != NULL);
assert(strcmp(delta, "<eo") == 0);
}

// EOS
detector.clear();
{
ASSERT_EOS_TYPE(detector.append(EOS_ID, "<eos>"), EOS);
assert(detector.getDelta() == NULL);
}

printf("✅ EosDetector with padding\n");
}


void testEosDetectorWithoutPadding() {
const char* stops[1] = { "<eos>" };
EosDetector detector(EOS_ID, 1, stops, 0, 0);

// "<eos>"
{
ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(2, "eo"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(3, "s>"), EOS);
assert(detector.getDelta() == NULL);
}

// " <"
detector.clear();
{
ASSERT_EOS_TYPE(detector.append(1, " <"), NOT_EOS);
char* delta = detector.getDelta();
assert(delta != NULL);
assert(strcmp(delta, " <") == 0);
}

// "<eos> "
detector.clear();
{
ASSERT_EOS_TYPE(detector.append(1, "<eos"), MAYBE_EOS);
ASSERT_EOS_TYPE(detector.append(2, "> "), NOT_EOS);
char* delta = detector.getDelta();
assert(delta != NULL);
assert(strcmp(delta, "<eos> ") == 0);
}

// EOS
detector.clear();
{
ASSERT_EOS_TYPE(detector.append(EOS_ID, "<eos>"), EOS);
assert(detector.getDelta() == NULL);
}

printf("✅ EosDetector without padding\n");
}

int main() {
testEosDetectorWithPadding();
testEosDetectorWithoutPadding();
return EXIT_SUCCESS;
}
Loading

0 comments on commit 72f8c3b

Please sign in to comment.