From 72f8c3b21912f98003a867098b342df1607f2426 Mon Sep 17 00:00:00 2001 From: b4rtaz Date: Fri, 31 May 2024 16:47:56 +0200 Subject: [PATCH] eos detector. --- .github/workflows/main.yml | 6 ++ Makefile | 2 + converter/convert-tokenizer-hf.py | 2 +- converter/convert-tokenizer-llama3.py | 1 + src/apps/dllama-api/dllama-api.cpp | 77 +++++++-------- src/tokenizer-test.cpp | 132 ++++++++++++++++++++++++++ src/tokenizer.cpp | 77 ++++++++++++++- src/tokenizer.hpp | 29 +++++- 8 files changed, 278 insertions(+), 48 deletions(-) create mode 100644 src/tokenizer-test.cpp diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4b50411..a8b13f6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -30,6 +30,7 @@ jobs: make dllama-api make funcs-test make quants-test + make tokenizer-test make transformer-test make llama2-tasks-test make grok1-tasks-test @@ -37,6 +38,8 @@ jobs: run: ./funcs-test - name: quants-test run: ./quants-test + - name: tokenizer-test + run: ./tokenizer-test - name: transformer-test run: ./transformer-test - name: llama2-tasks-test @@ -60,6 +63,7 @@ jobs: make dllama-api make funcs-test make quants-test + make tokenizer-test make transformer-test make llama2-tasks-test make grok1-tasks-test @@ -67,6 +71,8 @@ jobs: run: ./funcs-test - name: quants-test run: ./quants-test + - name: tokenizer-test + run: ./tokenizer-test - name: transformer-test run: ./transformer-test - name: llama2-tasks-test diff --git a/Makefile b/Makefile index e938e8f..d8104f7 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,8 @@ funcs-test: src/funcs-test.cpp funcs utils quants $(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o $(LIBS) quants-test: src/quants.cpp utils quants $(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o $(LIBS) +tokenizer-test: src/tokenizer-test.cpp tokenizer funcs utils quants + $(CXX) $(CXXFLAGS) src/tokenizer-test.cpp -o tokenizer-test tokenizer.o funcs.o utils.o quants.o $(LIBS) transformer-test: src/transformer-test.cpp funcs utils quants transformer socket $(CXX) $(CXXFLAGS) src/transformer-test.cpp -o transformer-test funcs.o utils.o quants.o transformer.o socket.o $(LIBS) llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks tokenizer diff --git a/converter/convert-tokenizer-hf.py b/converter/convert-tokenizer-hf.py index 7fdb4e4..cb1a3fd 100644 --- a/converter/convert-tokenizer-hf.py +++ b/converter/convert-tokenizer-hf.py @@ -56,7 +56,7 @@ def printUsage(): print() print('⭐ To create the tokenizer file you need to manually specify chat template values. Enter \\n for new line.') templateChat = {} - templateKeys = ['chat_message_start', 'chat_role_start', 'chat_role_end', 'chat_message_end', 'chat_generation_prompt'] + templateKeys = ['chat_message_start', 'chat_role_start', 'chat_role_end', 'chat_message_end', 'chat_generation_prompt', 'chat_extra_stop'] for key in templateKeys: value = input(f'⏩ Enter value for chat template key "{key}":\n') templateChat[key] = value.replace('\\n', '\n') diff --git a/converter/convert-tokenizer-llama3.py b/converter/convert-tokenizer-llama3.py index b861430..fdf0d65 100644 --- a/converter/convert-tokenizer-llama3.py +++ b/converter/convert-tokenizer-llama3.py @@ -36,6 +36,7 @@ 'chat_role_end': '<|end_header_id|>\n\n', 'chat_message_end': '<|eot_id|>', 'chat_generation_prompt': '<|start_header_id|>assistant<|end_header_id|>\n\n', + 'chat_extra_stop': '' } if __name__ == '__main__': diff --git a/src/apps/dllama-api/dllama-api.cpp b/src/apps/dllama-api/dllama-api.cpp index 23ffa90..0c6d211 100644 --- a/src/apps/dllama-api/dllama-api.cpp +++ b/src/apps/dllama-api/dllama-api.cpp @@ -238,30 +238,21 @@ class ApiServer { Sampler* sampler; AppArgs* args; TransformerSpec* spec; + EosDetector* eosDetector; NaiveCache naiveCache; - int eosId; - std::string eos; - public: - ApiServer( - Inference* inference, - Tokenizer* tokenizer, - Sampler* sampler, - AppArgs* args, - TransformerSpec* spec) { + ApiServer( Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, EosDetector* eosDetector) { this->inference = inference; this->tokenizer = tokenizer; this->sampler = sampler; this->args = args; this->spec = spec; - eosId = (tokenizer->chatEosId >= 0) ? tokenizer->chatEosId : tokenizer->eosId; - assert(eosId >= 0); - eos = tokenizer->vocab[eosId]; + this->eosDetector = eosDetector; } std::string buildChatPrompt(std::vector messages) { - assert(tokenizer->nChatTemplates == 5); + assert(tokenizer->nChatTemplates == 6); std::ostringstream buffer; for (const auto& message : messages) { @@ -308,7 +299,6 @@ class ApiServer { request.writeStreamStartChunk(); } - std::string delta; std::string buffer; size_t nStops = params.stop.size(); @@ -323,45 +313,27 @@ class ApiServer { int prevToken = token; token = sampler->sample(logits); - if (token == eosId) { - printf("🔴"); - break; - } - char* piece = tokenizer->decode(prevToken, token); + bool isSafe = isSafePiece(piece); + + int eosType = eosDetector->append(token, isSafe ? piece : ""); if (isSafePiece(piece)) { printf("%s", piece); fflush(stdout); - delta += piece; } - bool maybeEos = false; - size_t deltaSize = delta.size(); - if (nStops > 0 && deltaSize > 0) { - bool isEos = false; - size_t eosSize = eos.size(); - if (eos.compare(0, deltaSize, delta) == 0) { - if (eosSize <= deltaSize) { - isEos = true; - break; - } else { - maybeEos = true; - break; - } - } - if (isEos) { - printf("⭕"); - break; + if (eosType == NOT_EOS || eosType == EOS) { + char* delta = eosDetector->getDelta(); + if (delta != NULL) { + std::string deltaStr(delta); + if (params.stream) + writeChatCompletionChunk(request, deltaStr, false); + buffer += deltaStr; } + eosDetector->clear(); } - - if (!maybeEos) { - if (params.stream) - writeChatCompletionChunk(request, delta, false); - buffer += delta; - delta.clear(); - } + if (eosType == EOS) break; } } @@ -432,9 +404,24 @@ void handleModelsRequest(HttpRequest& request) { } void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) { + if (tokenizer->chatEosId < 0) { + printf("⛔ 0.8.0 version introduced a new format of the tokenizer that includes chatEosId. Please update your tokenizer.\n"); + exit(EXIT_FAILURE); + } + SocketServer* server = new SocketServer(args->port); + + const bool hasExtraStop = tokenizer->chatTemplate[5][0] != '\0'; + const int nStops = hasExtraStop ? 2 : 1; + char* stops[nStops]; + stops[0] = tokenizer->vocab[tokenizer->chatEosId]; + if (hasExtraStop) + stops[1] = tokenizer->chatTemplate[5]; + + EosDetector eosDetector(tokenizer->chatEosId, nStops, (const char**)stops, 1, 1); + ApiServer api(inference, tokenizer, sampler, args, spec, &eosDetector); + printf("Server URL: http://127.0.0.1:%d/v1/\n", args->port); - ApiServer api(inference, tokenizer, sampler, args, spec); std::vector routes = { { diff --git a/src/tokenizer-test.cpp b/src/tokenizer-test.cpp new file mode 100644 index 0000000..4785bf8 --- /dev/null +++ b/src/tokenizer-test.cpp @@ -0,0 +1,132 @@ +#include "tokenizer.hpp" + +#define ASSERT_EOS_TYPE(type, expected) \ + if (type != expected) { \ + printf("Expected %d, got %d (line: %d)\n", expected, type, __LINE__); \ + exit(1); \ + } + +#define EOS_ID 10000 + +void testEosDetectorWithPadding() { + const char* stops[2] = { "", "" }; + EosDetector detector(EOS_ID, 2, stops, 1, 1); + + // "" + { + ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(2, "eo"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(3, "s>"), EOS); + assert(detector.getDelta() == NULL); + } + + // " " + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(2, "stop"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(3, "> "), EOS); + assert(detector.getDelta() == NULL); + } + + // " " + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, " "), NOT_EOS); + + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, " ") == 0); + } + + // "! " + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, "!<"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(2, "eos"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(3, "> "), EOS); + + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, "!") == 0); + } + + // "! " + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, "XY"), NOT_EOS); + + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, "XY") == 0); + } + + // ""), EOS); + + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, ""), EOS); + assert(detector.getDelta() == NULL); + } + + printf("✅ EosDetector with padding\n"); +} + + +void testEosDetectorWithoutPadding() { + const char* stops[1] = { "" }; + EosDetector detector(EOS_ID, 1, stops, 0, 0); + + // "" + { + ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(2, "eo"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(3, "s>"), EOS); + assert(detector.getDelta() == NULL); + } + + // " <" + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, " <"), NOT_EOS); + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, " <") == 0); + } + + // " " + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, " "), NOT_EOS); + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, " ") == 0); + } + + // EOS + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(EOS_ID, ""), EOS); + assert(detector.getDelta() == NULL); + } + + printf("✅ EosDetector without padding\n"); +} + +int main() { + testEosDetectorWithPadding(); + testEosDetectorWithoutPadding(); + return EXIT_SUCCESS; +} diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp index b041f6d..8312e26 100644 --- a/src/tokenizer.cpp +++ b/src/tokenizer.cpp @@ -86,6 +86,7 @@ Tokenizer::Tokenizer(char* tokenizerPath, int modelVocabSize) { assert(version == 0); if (nChatTemplates > 0) { + assert(nChatTemplates == 6); unsigned int templateSizes[nChatTemplates]; if (fread(&templateSizes, sizeof(templateSizes), 1, file) != 1) { throw std::runtime_error("Cannot read chat template sizes"); @@ -424,4 +425,78 @@ void Sampler::setTemp(float temp) { void Sampler::setSeed(unsigned long long seed) { this->rngState = seed; -} \ No newline at end of file +} + +EosDetector::EosDetector(int eosId, size_t nStops, const char** stops, int paddingLeft, int paddingRight) { + this->eosId = eosId; + this->nStops = nStops; + this->stops = stops; + this->stopSizes = new size_t[nStops]; + for (size_t s = 0; s < nStops; s++) { + stopSizes[s] = strlen(stops[s]); + printf("🛑 stop: %s\n", stops[s]); + } + this->bufferPos = 0; + this->bufferSize = 0; + this->paddingLeft = paddingLeft; + this->paddingRight = paddingRight; +} + +EosDetector::~EosDetector() { + if (bufferSize > 0) + delete[] buffer; + delete[] stopSizes; +} + +EosDetectorType EosDetector::append(int tokenId, const char* piece) { + int pieceLength = strlen(piece); + int length = bufferPos + pieceLength + 1; + if (length > bufferSize) { + char* newBuffer = new char[length]; + if (bufferPos > 0) + memcpy(newBuffer, buffer, bufferPos); + if (bufferSize > 0) + delete[] buffer; + buffer = newBuffer; + } + memcpy(buffer + bufferPos, piece, pieceLength + 1); + bufferPos += pieceLength; + + // detection + + if (tokenId == eosId) { + eosPos = bufferPos - pieceLength; + return EOS; + } + eosPos = -1; + + for (size_t s = 0; s < nStops; s++) { + size_t stopSize = stopSizes[s]; + if (bufferPos > stopSize + paddingLeft + paddingRight) continue; + + for (int lo = 0; lo <= paddingLeft; lo++) { + int n = bufferPos - lo; + if (n == 0 || n > stopSize + paddingRight) continue; + if (n > stopSize) n = stopSize; + if (strncmp(buffer + lo, stops[s], n) == 0) { + if (n == stopSize) { + eosPos = lo; + return EOS; + } + return MAYBE_EOS; + } + } + } + return NOT_EOS; +} + +char* EosDetector::getDelta() { + if (eosPos == -1) return buffer; + if (eosPos == 0) return NULL; + buffer[eosPos] = '\0'; + return buffer; +} + +void EosDetector::clear() { + bufferPos = 0; +} diff --git a/src/tokenizer.hpp b/src/tokenizer.hpp index dd3667c..14f4165 100644 --- a/src/tokenizer.hpp +++ b/src/tokenizer.hpp @@ -44,7 +44,7 @@ class Tokenizer { char** vocab; int bosId; int eosId; - int chatEosId; // -1 if not used + int chatEosId; int nChatTemplates; char** chatTemplate; @@ -78,4 +78,31 @@ class Sampler { void setSeed(unsigned long long rngSeed); }; +enum EosDetectorType { + MAYBE_EOS = 0, + EOS = 1, + NOT_EOS = 2, +}; + +class EosDetector { +private: + int eosId; + size_t nStops; + const char** stops; + size_t* stopSizes; + size_t bufferPos; + size_t bufferSize; + int eosPos; + int paddingLeft; + int paddingRight; +public: + char* buffer; + EosDetector(int eosId, size_t nStops, const char** stops, int paddingLeft, int paddingRight); + ~EosDetector(); + + EosDetectorType append(int tokenId, const char* piece); + char* getDelta(); + void clear(); +}; + #endif