From 6eccd30204e3fc64da793cbdac30045787b754ec Mon Sep 17 00:00:00 2001
From: Bart Tadych
Date: Fri, 31 May 2024 22:39:02 +0200
Subject: [PATCH] feat: add to tokenizer chat configuration. (#76)

---
 .github/workflows/main.yml            |   6 +
 Makefile                              |   2 +
 converter/convert-tokenizer-hf.py     |  71 +++++++++
 converter/convert-tokenizer-llama3.py |  45 +++---
 converter/tokenizer-writer.py         |  57 ++++++++
 src/app.cpp                           |   2 +-
 src/apps/dllama-api/dllama-api.cpp    | 100 ++++++-------
 src/apps/dllama/dllama.cpp            | 200 +++++++++++++++-----------
 src/tokenizer-test.cpp                | 170 ++++++++++++++++++++++
 src/tokenizer.cpp                     | 194 +++++++++++++++++++++----
 src/tokenizer.hpp                     |  56 +++++++-
 11 files changed, 712 insertions(+), 191 deletions(-)
 create mode 100644 converter/convert-tokenizer-hf.py
 create mode 100644 converter/tokenizer-writer.py
 create mode 100644 src/tokenizer-test.cpp

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 4b50411..a8b13f6 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -30,6 +30,7 @@ jobs:
           make dllama-api
           make funcs-test
           make quants-test
+          make tokenizer-test
           make transformer-test
           make llama2-tasks-test
           make grok1-tasks-test
@@ -37,6 +38,8 @@ jobs:
         run: ./funcs-test
       - name: quants-test
         run: ./quants-test
+      - name: tokenizer-test
+        run: ./tokenizer-test
       - name: transformer-test
         run: ./transformer-test
       - name: llama2-tasks-test
@@ -60,6 +63,7 @@ jobs:
           make dllama-api
           make funcs-test
           make quants-test
+          make tokenizer-test
           make transformer-test
           make llama2-tasks-test
           make grok1-tasks-test
@@ -67,6 +71,8 @@ jobs:
         run: ./funcs-test
       - name: quants-test
         run: ./quants-test
+      - name: tokenizer-test
+        run: ./tokenizer-test
       - name: transformer-test
         run: ./transformer-test
       - name: llama2-tasks-test
diff --git a/Makefile b/Makefile
index e938e8f..d8104f7 100644
--- a/Makefile
+++ b/Makefile
@@ -42,6 +42,8 @@ funcs-test: src/funcs-test.cpp funcs utils quants
 	$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o $(LIBS)
 quants-test: src/quants.cpp utils quants
 	$(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o $(LIBS)
+tokenizer-test: src/tokenizer-test.cpp tokenizer funcs utils quants
+	$(CXX) $(CXXFLAGS) src/tokenizer-test.cpp -o tokenizer-test tokenizer.o funcs.o utils.o quants.o $(LIBS)
 transformer-test: src/transformer-test.cpp funcs utils quants transformer socket
 	$(CXX) $(CXXFLAGS) src/transformer-test.cpp -o transformer-test funcs.o utils.o quants.o transformer.o socket.o $(LIBS)
 llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks tokenizer
diff --git a/converter/convert-tokenizer-hf.py b/converter/convert-tokenizer-hf.py
new file mode 100644
index 0000000..cb1a3fd
--- /dev/null
+++ b/converter/convert-tokenizer-hf.py
@@ -0,0 +1,71 @@
+import sys
+import json
+import os
+writer = __import__('tokenizer-writer')
+
+def openJson(path):
+    with open(path, 'r', encoding='utf-8') as file:
+        return json.load(file)
+
+def printUsage():
+    print('Usage: python convert-tokenizer-hf.py <folderPath> <name>')
+    print()
+    print('Options:')
+    print('  <folderPath>  The path to the folder with tokenizer.json and tokenizer_config.json')
+    print('  <name>        The name of the tokenizer (e.g. "llama3")')
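+
+# Hedged example of the interactive input requested below: for a
+# Llama-3-style template (mirroring converter/convert-tokenizer-llama3.py in
+# this patch), the six values would be:
+#   chat_message_start:     (empty)
+#   chat_role_start:        <|start_header_id|>
+#   chat_role_end:          <|end_header_id|>\n\n
+#   chat_message_end:       <|eot_id|>
+#   chat_generation_prompt: <|start_header_id|>assistant<|end_header_id|>\n\n
+#   chat_extra_stop:        (empty)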
"llama3")') + +if __name__ == '__main__': + if (len(sys.argv) < 2): + printUsage() + exit(1) + + dirPath = sys.argv[1] + name = sys.argv[2] + tokenizerConfig = openJson(os.path.join(dirPath, 'tokenizer_config.json')) + tokenizer = openJson(os.path.join(dirPath, 'tokenizer.json')) + + assert(tokenizerConfig['tokenizer_class'] == 'PreTrainedTokenizerFast') + assert(tokenizer['model']['type'] == 'BPE') + i = 0 + tokens = [] + scores = [] + bosId = None + eosId = None + for token in tokenizer['model']['vocab'].keys(): + assert(tokenizer['model']['vocab'][token] == i) + tokens.append(token.encode('utf8')) + scores.append(-float(i)) + i += 1 + if ('added_tokens' in tokenizer): + for at in tokenizer['added_tokens']: + assert(at['id'] == i) + tokens.append(at['content'].encode('utf8')) + scores.append(-float(i)) + if (at['content'] == tokenizerConfig['bos_token']): + bosId = i + if (at['content'] == tokenizerConfig['eos_token']): + eosId = i + i += 1 + + templateChat = None + if ('chat_template' in tokenizerConfig): + template = tokenizerConfig['chat_template'] + print('⭐ Found chat template:') + print() + print(template.replace('\n', '\\n')) + print() + print('⭐ To create the tokenizer file you need to manually specify chat template values. Enter \\n for new line.') + templateChat = {} + templateKeys = ['chat_message_start', 'chat_role_start', 'chat_role_end', 'chat_message_end', 'chat_generation_prompt', 'chat_extra_stop'] + for key in templateKeys: + value = input(f'⏩ Enter value for chat template key "{key}":\n') + templateChat[key] = value.replace('\\n', '\n') + + outputFileName = f'dllama_tokenizer_{name}.t' + with open(outputFileName, 'wb') as outputFile: + writer.writeTokenizer(outputFile, { + 'bos_id': bosId, + 'eos_id': eosId, + 'chat_eos_id': eosId, + }, templateChat, tokens, scores) + print(f'✅ Created {outputFileName}') diff --git a/converter/convert-tokenizer-llama3.py b/converter/convert-tokenizer-llama3.py index c028493..0ad14d5 100644 --- a/converter/convert-tokenizer-llama3.py +++ b/converter/convert-tokenizer-llama3.py @@ -1,6 +1,7 @@ import sys import struct import base64 +writer = __import__('tokenizer-writer') # Format of input file: # ``` @@ -28,16 +29,32 @@ ] bosId = 128000 eosId = 128001 +chatEosId = 128009 +chatTemplate = { + 'chat_message_start': '', + 'chat_role_start': '<|start_header_id|>', + 'chat_role_end': '<|end_header_id|>\n\n', + 'chat_message_end': '<|eot_id|>', + 'chat_generation_prompt': '<|start_header_id|>assistant<|end_header_id|>\n\n', + 'chat_extra_stop': '' +} + +def printUsage(): + print('Usage: python convert-tokenizer-llama3.py ') + print() + print('Options:') + print(' The path to the Llama 3 tokenizer model (tokenizer.model)') if __name__ == '__main__': if (len(sys.argv) < 2): - print('Invalid usage') + printUsage() exit(1) modelPath = sys.argv[1] + outputFileName = 'dllama_tokenizer_llama3.t' with open(modelPath, 'r') as inputFile: - with open('dllama_tokenizer_llama3.t', 'wb') as outputFile: + with open(outputFileName, 'wb') as outputFile: inputLines = inputFile.readlines() nLines = len(inputLines) @@ -58,22 +75,10 @@ scores.append(score) specialTokenIndex += 1 - vocabSize = len(tokens) - maxTokenLength = max(len(t) for t in tokens) - - outputFile.write(struct.pack('IIIiii', - 0x567123, - vocabSize, - maxTokenLength, - bosId, - eosId, - -1)) - - for i in range(0, vocabSize): - outputFile.write(struct.pack('fI', scores[i], len(tokens[i]))) - outputFile.write(tokens[i]) + writer.writeTokenizer(outputFile, { + 'bos_id': bosId, + 'eos_id': eosId, + 
+    header = struct.pack('i', 0x567124)
+
+    nTokens = len(tokens)
+    maxTokenLength = max(len(t) for t in tokens)
+
+    params['version'] = 0
+    params['vocab_size'] = nTokens
+    params['max_token_length'] = maxTokenLength
+    if (chatTemplate):
+        params['chat_template'] = len(chatTemplate)
+
+    data = b''
+    for key in params:
+        if key in headerKeys:
+            data += struct.pack('ii', headerKeys[key], params[key])
+        else:
+            print(f'Unknown header key: {key}')
+
+    header += struct.pack('i', len(header) * 2 + len(data))
+    file.write(header)
+    file.write(data)
+
+    print(params)
+    if (chatTemplate):
+        print(chatTemplate)
+
+    if (chatTemplate):
+        chatTemplateValue = list(chatTemplate.values())
+        nChatTemplates = len(chatTemplateValue)
+        for i in range(0, nChatTemplates):
+            file.write(struct.pack('I', len(chatTemplateValue[i].encode('utf8'))))
+        for i in range(0, nChatTemplates):
+            data = chatTemplateValue[i].encode('utf8')
+            if (len(data) > 0):
+                file.write(data)
+
+    for i in range(0, nTokens):
+        size = len(tokens[i])
+        assert(size > 0)
+        file.write(struct.pack('fI', scores[i], size))
+        file.write(tokens[i])
diff --git a/src/app.cpp b/src/app.cpp
index ca89b37..4ee1a4a 100644
--- a/src/app.cpp
+++ b/src/app.cpp
@@ -113,12 +113,12 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
     TransformerSpec spec = Transformer::loadSpecFromFile(args->modelPath, nSlices, args->weightsFloatType, args->bufferFloatType);
     TransformerArch arch = TransformerArchFactory::create(&spec);
+    Tokenizer tokenizer(args->tokenizerPath, spec.vocabSize);
 
     if (args->steps == 0 || args->steps > spec.seqLen) {
         args->steps = spec.seqLen;
     }
 
-    Tokenizer tokenizer(args->tokenizerPath, spec.vocabSize);
     Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool);
     socketPool->setTurbo(true);
diff --git a/src/apps/dllama-api/dllama-api.cpp b/src/apps/dllama-api/dllama-api.cpp
index a82ec9d..d9887d0 100644
--- a/src/apps/dllama-api/dllama-api.cpp
+++ b/src/apps/dllama-api/dllama-api.cpp
@@ -165,22 +165,6 @@ class Router {
     }
 };
 
-/*
-Generally speaking, the tokenizer.config.json would contain the chat template for the model
-and depending on the model used you could set the chat template to follow
-could possibly just for simplicity set this in ServerArgs with --chat-template
-for this code draft I am assuming the use of llama 3 instruct
-*/
-std::string buildChatPrompt(Tokenizer *tokenizer, const std::vector<ChatMessage> &messages){
-    std::ostringstream oss;
-    for (const auto& message : messages) {
-        oss << "<|start_header_id|>" << message.role << "<|end_header_id|>\n\n" << message.content << "<|eot_id|>";
-    }
-
-    oss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
-    return oss.str();
-}
-
 void writeChatCompletionChunk(HttpRequest &request, const std::string &delta, const bool stop){
     ChunkChoice choice;
     if (stop) {
@@ -254,20 +238,34 @@ class ApiServer {
     Sampler* sampler;
     AppArgs* args;
     TransformerSpec* spec;
+    EosDetector* eosDetector;
     NaiveCache naiveCache;
 
 public:
-    ApiServer(
-        Inference* inference,
-        Tokenizer* tokenizer,
-        Sampler* sampler,
-        AppArgs* args,
-        TransformerSpec* spec) {
+    ApiServer(Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, EosDetector* eosDetector) {
         this->inference = inference;
         this->tokenizer = tokenizer;
         this->sampler = sampler;
         this->args = args;
         this->spec = spec;
+        this->eosDetector = eosDetector;
+    }
+
+    std::string buildChatPrompt(std::vector<ChatMessage> messages) {
+        assert(tokenizer->nChatTemplates == 6);
+
+        std::ostringstream buffer;
+        for (const auto& message : messages) {
+            buffer << tokenizer->chatTemplate[0]; // chatMessageStart
+            buffer << tokenizer->chatTemplate[1]; // chatRoleStart
+            buffer << message.role;
+            buffer << tokenizer->chatTemplate[2]; // chatRoleEnd
+            buffer << message.content;
+            buffer << tokenizer->chatTemplate[3]; // chatMessageEnd
+        }
+
+        buffer << tokenizer->chatTemplate[4]; // chatGenerationPrompt
+        return buffer.str();
+    }
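+
+    // Example of buildChatPrompt output (assuming the Llama 3 template
+    // written by converter/convert-tokenizer-llama3.py): a single user
+    // message "Hello" renders as
+    //   <|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>
+    //   <|start_header_id|>assistant<|end_header_id|>\n\n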
&request, const std::string &delta, const bool stop){ ChunkChoice choice; if (stop) { @@ -254,20 +238,34 @@ class ApiServer { Sampler* sampler; AppArgs* args; TransformerSpec* spec; + EosDetector* eosDetector; NaiveCache naiveCache; public: - ApiServer( - Inference* inference, - Tokenizer* tokenizer, - Sampler* sampler, - AppArgs* args, - TransformerSpec* spec) { + ApiServer( Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, EosDetector* eosDetector) { this->inference = inference; this->tokenizer = tokenizer; this->sampler = sampler; this->args = args; this->spec = spec; + this->eosDetector = eosDetector; + } + + std::string buildChatPrompt(std::vector messages) { + assert(tokenizer->nChatTemplates == 6); + + std::ostringstream buffer; + for (const auto& message : messages) { + buffer << tokenizer->chatTemplate[0]; // chatMessageStart + buffer << tokenizer->chatTemplate[1]; // chatRoleStart + buffer << message.role; + buffer << tokenizer->chatTemplate[2]; // chatRoleEnd + buffer << message.content; + buffer << tokenizer->chatTemplate[3]; // chatMessageEnd + } + + buffer << tokenizer->chatTemplate[4]; // chatGenerationPrompt + return buffer.str(); } void complete(HttpRequest& request) { @@ -280,7 +278,7 @@ class ApiServer { printf("🔸"); fflush(stdout); - std::string inputPrompt = buildChatPrompt(tokenizer, deltaPrompt); + std::string inputPrompt = buildChatPrompt(deltaPrompt); int promptLength = inputPrompt.size(); int nPromptTokens; int promptTokens[promptLength + 3]; @@ -301,7 +299,6 @@ class ApiServer { request.writeStreamStartChunk(); } - std::string delta; std::string buffer; size_t nStops = params.stop.size(); @@ -316,47 +313,27 @@ class ApiServer { int prevToken = token; token = sampler->sample(logits); - if (token == tokenizer->eosId) { - printf("🔴"); - break; - } - char* piece = tokenizer->decode(prevToken, token); + bool isSafe = isSafePiece(piece); + + EosDetectorType eosType = eosDetector->append(token, isSafe ? piece : ""); if (isSafePiece(piece)) { printf("%s", piece); fflush(stdout); - delta += piece; } - bool maybeEos = false; - size_t deltaSize = delta.size(); - if (nStops > 0 && deltaSize > 0) { - bool eos = false; - for (size_t s = 0; s < nStops; s++) { - size_t stopSize = params.stop[s].size(); - if (params.stop[s].compare(0, deltaSize, delta) == 0) { - if (stopSize <= deltaSize) { - eos = true; - break; - } else { - maybeEos = true; - break; - } - } + if (eosType == NOT_EOS || eosType == EOS) { + char* delta = eosDetector->getDelta(); + if (delta != NULL) { + std::string deltaStr(delta); + if (params.stream) + writeChatCompletionChunk(request, deltaStr, false); + buffer += deltaStr; } - if (eos) { - printf("⭕"); - break; - } - } - - if (!maybeEos) { - if (params.stream) - writeChatCompletionChunk(request, delta, false); - buffer += delta; - delta.clear(); + eosDetector->clear(); } + if (eosType == EOS) break; } } @@ -427,9 +404,18 @@ void handleModelsRequest(HttpRequest& request) { } void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) { + if (tokenizer->chatEosId < 0) { + printf("⛔ 0.8.0 version introduced a new format of the tokenizer that includes chatEosId. 
+    ApiServer api(inference, tokenizer, sampler, args, spec, &eosDetector);
+
     printf("Server URL: http://127.0.0.1:%d/v1/\n", args->port);
 
-    ApiServer api(inference, tokenizer, sampler, args, spec);
     std::vector<Route> routes = {
         {
diff --git a/src/apps/dllama/dllama.cpp b/src/apps/dllama/dllama.cpp
index 1972e48..a5ee790 100644
--- a/src/apps/dllama/dllama.cpp
+++ b/src/apps/dllama/dllama.cpp
@@ -4,6 +4,9 @@
 #include <cstring>
 #include <cstdlib>
 #include <cassert>
+#include <stdexcept>
+#include <sstream>
+
 #include "../../utils.hpp"
 #include "../../socket.hpp"
 #include "../../transformer.hpp"
@@ -12,20 +15,19 @@
 #include "../../app.hpp"
 
 void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
-    assert(args->prompt != NULL);
+    if (args->prompt == NULL)
+        throw std::runtime_error("Prompt is required");
 
     // encode the (string) prompt into tokens sequence
     int numPromptTokens = 0;
-    int* promptTokens = (int*)malloc((strlen(args->prompt) + 3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
+    int* promptTokens = new int[strlen(args->prompt) + 3]; // +3 for '\0', ?BOS, ?EOS
 
     // TODO: this is a hack for Grok1. We should have a more general way to handle this
     bool addBos = spec->archType != GROK1;
 
     tokenizer->encode(args->prompt, promptTokens, &numPromptTokens, addBos, false);
-    if (numPromptTokens < 1) {
-        fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
-        exit(EXIT_FAILURE);
-    }
+    if (numPromptTokens < 1)
+        throw std::runtime_error("Expected at least 1 prompt token");
 
     // start the main loop
     long start = 0; // used to time our code, only initialized after first iteration
@@ -73,14 +75,14 @@ void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer
         if (args->benchmark)
             printf("🔶 G %4ld ms I %4ld ms T %4ld ms S %6ld kB R %6ld kB ", generationTime, inferenceTime, transferTime, sentBytes / 1024, recvBytes / 1024);
-        safePrintf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
+        safePrintf(piece);
         if (args->benchmark)
             printf("\n");
         fflush(stdout);
 
         token = next;
     }
 
-    free(promptTokens);
+    delete[] promptTokens;
 
     if (!args->benchmark) printf("\n");
     double avgGenerationTime = totalGenerationTime / (double)pos;
@@ -91,90 +93,118 @@ void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer
     printf("Avg transfer time: %.2f ms\n", totalTransferTime / (double)pos);
 }
 
-void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
-    char* cliSystemPrompt = NULL;
-    char* cliUserPrompt = NULL;
-    // buffers for reading the system prompt and user prompt from stdin
-    // you'll notice they are soomewhat haphazardly and unsafely set atm
-    char systemPrompt[512];
-    char userPrompt[512];
-    const size_t renderedPromptSize = 1152;
-    char renderedPrompt[renderedPromptSize];
-    int numPromptTokens = 0;
-    int* promptTokens = (int*)malloc(1152 * sizeof(int));
-    int userIdx;
-
-    // start the main loop
-    int8_t userTurn = 1; // user starts
-    int next;  // will store the next token in the sequence
-    int token; // stores the current token to feed into the transformer
-    int prev_token;
-    pos_t pos = 0; // position in the sequence
-    while (pos < args->steps) {
-        // when it is the user's turn to contribute tokens to the dialog...
-        if (userTurn) {
-            // get the (optional) system prompt at position 0
-            if (pos == 0) {
-                // at position 0, the user can also contribute a system prompt
-                if (cliSystemPrompt == NULL) {
-                    // system prompt was not passed in, attempt to get it from stdin
-                    readStdin("💻 Enter system prompt (optional): ", systemPrompt, sizeof(systemPrompt));
-                } else {
-                    // system prompt was passed in, use it
-                    strcpy(systemPrompt, cliSystemPrompt);
-                }
-            }
-            // get the user prompt
-            if (pos == 0 && cliUserPrompt != NULL) {
-                // user prompt for position 0 was passed in, use it
-                strcpy(userPrompt, cliUserPrompt);
-            } else {
-                // otherwise get user prompt from stdin
-                readStdin("👱 User: ", userPrompt, sizeof(userPrompt));
-            }
-            // render user/system prompts into the Llama 2 Chat schema
-            if (pos == 0 && systemPrompt[0] != '\0') {
-                char systemTemplate[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
-                snprintf(renderedPrompt, renderedPromptSize, systemTemplate, systemPrompt, userPrompt);
-            } else {
-                char userTemplate[] = "[INST] %s [/INST]";
-                snprintf(renderedPrompt, renderedPromptSize, userTemplate, userPrompt);
-            }
-            // encode the rendered prompt into tokens
-            tokenizer->encode(renderedPrompt, promptTokens, &numPromptTokens, true, false);
-            userIdx = 0; // reset the user index
-            userTurn = 0;
-            printf("🤖 Assistant: ");
-        }
-
-        // determine the token to pass into the transformer next
-        if (userIdx < numPromptTokens) {
-            // if we are still processing the input prompt, force the next prompt token
-            token = promptTokens[userIdx++];
-        } else {
-            // otherwise use the next token sampled from previous turn
-            token = next;
-        }
-        // EOS token ends the Assistant turn
-        if (token == tokenizer->eosId) {
-            userTurn = 1;
-        }
-
-        // forward the transformer to get logits for the next token
-        float* logits = inference->infer(token, pos);
-        next = sampler->sample(logits);
-        pos++;
-
-        if (userIdx >= numPromptTokens && next != 2) {
-            // the Assistant is responding, so print its output
-            char* piece = tokenizer->decode(token, next);
-            safePrintf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
-            fflush(stdout);
-        }
-        if (next == tokenizer->eosId) { printf("\n"); }
-    }
-    printf("\n");
-    free(promptTokens);
-}
+size_t readStdin(const char* guide, char* buffer, size_t bufsize) {
+    fflush(stdin);
+    // read a line from stdin, up to but not including \n
+    printf("%s", guide);
+    if (fgets(buffer, bufsize, stdin) != NULL) {
+        size_t len = strlen(buffer);
+        if (len > 0 && buffer[len - 1] == '\n') {
+            buffer[len - 1] = '\0'; // strip newline
+            len--;
+        }
+        return len;
+    }
+    return 0;
+}
"unsafe" bytes - fflush(stdout); + size_t sysPromptLength = readStdin("💻 System prompt (optional): ", inputBuffer, sizeof(inputBuffer)); + if (sysPromptLength > 0) { + std::string sysPrompt = inputBuffer; + inputPrompt += buildMessage("system", sysPrompt, false); } - if (next == tokenizer->eosId) { printf("\n"); } + + pos_t pos = 0; + int token; + do { + size_t userPromptLength; + do { + userPromptLength = readStdin("\n👱 User\n> ", inputBuffer, sizeof(inputBuffer)); + } while (userPromptLength == 0); + + std::string userPrompt = inputBuffer; + inputPrompt += buildMessage("user", userPrompt, true); + + int* inputTokens = new int[inputPrompt.size() + 3]; + int nInputTokens; + tokenizer->encode((char*)inputPrompt.c_str(), inputTokens, &nInputTokens, true, false); + + pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, pos + nInputTokens - 1); + for (pos_t i = 0; pos < userPromptEndPos; pos++, i++) { + inference->infer(inputTokens[i], pos); + token = inputTokens[i + 1]; + } + + printf("\n🤖 Assistant\n"); + + for (; pos < spec->seqLen; pos++) { + int prevToken = token; + float* logits = inference->infer(token, pos); + token = sampler->sample(logits); + char* piece = tokenizer->decode(prevToken, token); + bool isSafe = isSafePiece(piece); + EosDetectorType eosType = eosDetector->append(token, isSafe ? piece : ""); + if (eosType == NOT_EOS || eosType == EOS) { + char* delta = eosDetector->getDelta(); + if (delta != NULL) { + printf("%s", delta); + fflush(stdout); + } + eosDetector->clear(); + } + if (eosType == EOS) break; + } + + inputPrompt.clear(); + } while (pos < spec->seqLen); + + printf("(end of context)\n"); } - printf("\n"); - free(promptTokens); +}; + +void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) { + TokenizerStops stops(tokenizer); + EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength); + + Chat chat(inference, tokenizer, sampler, args, spec, &eosDetector); + chat.chat(); } void worker(AppArgs* args) { diff --git a/src/tokenizer-test.cpp b/src/tokenizer-test.cpp new file mode 100644 index 0000000..e1dec39 --- /dev/null +++ b/src/tokenizer-test.cpp @@ -0,0 +1,170 @@ +#include +#include +#include +#include "tokenizer.hpp" + +#define ASSERT_EOS_TYPE(type, expected) \ + if (type != expected) { \ + printf("Expected %d, got %d (line: %d)\n", expected, type, __LINE__); \ + exit(1); \ + } + +#define EOS_ID 10000 + +void testEosDetectorWithPadding() { + const char* stops[2] = { "", "" }; + EosDetector detector(EOS_ID, 2, stops, 1, 1); + + // "" + { + ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(2, "eo"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(3, "s>"), EOS); + assert(detector.getDelta() == NULL); + } + + // " " + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(2, "stop"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(3, "> "), EOS); + assert(detector.getDelta() == NULL); + } + + // " " + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, " "), NOT_EOS); + + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, " ") == 0); + } + + // "! 
" + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, "!<"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(2, "eos"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(3, "> "), EOS); + + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, "!") == 0); + } + + // "! " + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, "XY"), NOT_EOS); + + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, "XY") == 0); + } + + // ""), EOS); + + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, ""), EOS); + assert(detector.getDelta() == NULL); + } + + printf("✅ EosDetector with padding\n"); +} + + +void testEosDetectorWithLongPadding() { + const char* stops[1] = { "|end|" }; + EosDetector detector(EOS_ID, 1, stops, 5, 5); + + // "lipsum" + { + ASSERT_EOS_TYPE(detector.append(1, "lipsum"), NOT_EOS); + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, "lipsum") == 0); + } + + // "lorem" + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, "lorem"), NOT_EOS); + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, "lorem") == 0); + } + + // "lorem|enQ" + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, "lorem|"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(2, "enQ"), NOT_EOS); + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, "lorem|enQ") == 0); + } + + printf("✅ EosDetector with long padding\n"); +} + +void testEosDetectorWithoutPadding() { + const char* stops[1] = { "" }; + EosDetector detector(EOS_ID, 1, stops, 0, 0); + + // "" + { + ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(2, "eo"), MAYBE_EOS); + ASSERT_EOS_TYPE(detector.append(3, "s>"), EOS); + assert(detector.getDelta() == NULL); + } + + // " <" + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, " <"), NOT_EOS); + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, " <") == 0); + } + + // " " + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(1, " "), NOT_EOS); + char* delta = detector.getDelta(); + assert(delta != NULL); + assert(strcmp(delta, " ") == 0); + } + + // EOS + detector.clear(); + { + ASSERT_EOS_TYPE(detector.append(EOS_ID, ""), EOS); + assert(detector.getDelta() == NULL); + } + + printf("✅ EosDetector without padding\n"); +} + +int main() { + testEosDetectorWithPadding(); + testEosDetectorWithLongPadding(); + testEosDetectorWithoutPadding(); + return EXIT_SUCCESS; +} diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp index 8cdaae9..3d189de 100644 --- a/src/tokenizer.cpp +++ b/src/tokenizer.cpp @@ -35,26 +35,84 @@ void safePrintf(char *piece) { } } -Tokenizer::Tokenizer(char* tokenizerPath, int vocabSize) { - // i should have written the vocab_size into the tokenizer file... 
+
+    if (magic == 0x567123) {
+        TokenizerOldHeader header;
+        if (fread(&header, sizeof(TokenizerOldHeader), 1, file) != 1)
+            throw std::runtime_error("Cannot read tokenizer header");
+        maxTokenLength = header.maxTokenLength;
+        vocabSize = header.vocabSize;
+        bosId = header.bosId;
+        eosId = header.eosId;
+    } else if (magic == 0x567124) {
+        int headerSize;
+        if (fread(&headerSize, sizeof(int), 1, file) != 1)
+            throw std::runtime_error("Cannot read tokenizer header size");
+        int nKv = (headerSize - 2 * sizeof(int)) / sizeof(int);
+        int buffer[nKv];
+        if (fread(&buffer, nKv * sizeof(int), 1, file) != 1) {
+            throw std::runtime_error("Cannot read header values");
+        }
+        int version = -1;
+        for (int i = 0; i < nKv; i += 2) {
+            int key = buffer[i];
+            int value = buffer[i + 1];
+            if (key == TOK_VERSION) version = value;
+            else if (key == TOK_VOCAB_SIZE) vocabSize = value;
+            else if (key == MAX_TOKEN_LENGTH) maxTokenLength = (unsigned int)value;
+            else if (key == BOS_ID) bosId = value;
+            else if (key == EOS_ID) eosId = value;
+            else if (key == CHAT_EOS_ID) chatEosId = value;
+            else if (key == CHAT_TEMPLATE) nChatTemplates = value;
+            else if (key == PAD_ID) {} // ignore
+            else {
+                throw std::runtime_error("Invalid tokenizer header key");
+            }
+        }
+        assert(version == 0);
 
-    if (header.magic != 0x567123 || header.vocabSize != vocabSize)
+        if (nChatTemplates > 0) {
+            assert(nChatTemplates == 6);
+            unsigned int templateSizes[nChatTemplates];
+            if (fread(&templateSizes, sizeof(templateSizes), 1, file) != 1) {
+                throw std::runtime_error("Cannot read chat template sizes");
+            }
+            chatTemplate = new char*[nChatTemplates];
+            for (int t = 0; t < nChatTemplates; t++) {
+                chatTemplate[t] = new char[templateSizes[t] + 1];
+                if (templateSizes[t] == 0) {
+                    chatTemplate[t][0] = '\0';
+                } else if (fread(chatTemplate[t], templateSizes[t], 1, file) != 1) {
+                    throw std::runtime_error("Cannot read chat template");
+                }
+                printf("📄 chatTemplate[%d]: %s\n", t, chatTemplate[t]);
+            }
+        }
+    } else {
         throw std::runtime_error("Invalid tokenizer file");
+    }
+
+    if (maxTokenLength < 1 || vocabSize != modelVocabSize) {
+        throw std::runtime_error("Tokenizer file is invalid or incompatible with model");
+    }
 
-    maxTokenLength = header.maxTokenLength;
-    bosId = header.bosId;
-    eosId = header.eosId;
     if (bosId >= 0) printf("📄 bosId: %d\n", bosId);
     if (eosId >= 0) printf("📄 eosId: %d\n", eosId);
+    if (chatEosId >= 0) printf("📄 chatEosId: %d\n", chatEosId);
 
     // malloc space to hold the scores and the strings
     vocab = (char**)malloc(vocabSize * sizeof(char*));
@@ -80,6 +138,10 @@ Tokenizer::Tokenizer(char* tokenizerPath, int vocabSize) {
 }
 
 Tokenizer::~Tokenizer() {
+    if (nChatTemplates > 0) {
+        for (int t = 0; t < nChatTemplates; t++) delete[] chatTemplate[t];
+        delete[] chatTemplate;
+    }
     for (int i = 0; i < vocabSize; i++) { free(vocab[i]); }
     free(vocab);
     free(vocabScores);
@@ -140,7 +202,9 @@ void Tokenizer::encode(char *text, int *tokens, int *nTokens, bool addBos, bool addEos) {
     if (text[0] != '\0') {
         char space[] = " ";
         int dummy_prefix = str_lookup(space, sortedVocab, vocabSize);
-        tokens[(*nTokens)++] = dummy_prefix;
+        // TODO: if the vocab has no plain space token, skip the dummy prefix
+        // instead of indexing with -1 (this avoids a segmentation fault)
+        if (dummy_prefix != -1)
+            tokens[(*nTokens)++] = dummy_prefix;
     }
 
     // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
@@ -305,18 +369,6 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, f
     return probindex[last_idx].index; // in case of rounding errors
 }
 
-void readStdin(const char* guide, char* buffer, size_t bufsize) {
-    fflush(stdin);
-    // read a line from stdin, up to but not including \n
-    printf("%s", guide);
-    if (fgets(buffer, bufsize, stdin) != NULL) {
-        size_t len = strlen(buffer);
-        if (len > 0 && buffer[len - 1] == '\n') {
-            buffer[len - 1] = '\0'; // strip newline
-        }
-    }
-}
-
 Sampler::Sampler(int vocab_size, float temperature, float topp, unsigned long long rngSeed) {
     this->vocab_size = vocab_size;
     this->temperature = temperature;
@@ -361,4 +413,98 @@ void Sampler::setTemp(float temp) {
 
 void Sampler::setSeed(unsigned long long seed) {
     this->rngState = seed;
-}
\ No newline at end of file
+}
+
+TokenizerStops::TokenizerStops(Tokenizer* tokenizer) {
+    assert(tokenizer->nChatTemplates == 6);
+    const bool hasExtraStop = tokenizer->chatTemplate[5][0] != '\0';
+    nStops = hasExtraStop ? 2 : 1;
+    char** s = new char*[nStops];
+    s[0] = tokenizer->vocab[tokenizer->chatEosId];
+    if (hasExtraStop)
+        s[1] = tokenizer->chatTemplate[5];
+    maxStopLength = 0;
+    for (size_t i = 0; i < nStops; i++) {
+        size_t len = strlen(s[i]);
+        if (len > maxStopLength) maxStopLength = len;
+    }
+    stops = (const char**)s;
+}
+
+TokenizerStops::~TokenizerStops() {
+    delete[] stops;
+}
+
+EosDetector::EosDetector(int eosId, size_t nStops, const char** stops, int paddingLeft, int paddingRight) {
+    this->eosId = eosId;
+    this->nStops = nStops;
+    this->stops = stops;
+    this->stopSizes = new size_t[nStops];
+    for (size_t s = 0; s < nStops; s++) {
+        stopSizes[s] = strlen(stops[s]);
+        printf("🛑 stop: %s\n", stops[s]);
+    }
+    this->bufferPos = 0;
+    this->bufferSize = 0;
+    this->paddingLeft = paddingLeft;
+    this->paddingRight = paddingRight;
+}
+
+EosDetector::~EosDetector() {
+    if (bufferSize > 0)
+        delete[] buffer;
+    delete[] stopSizes;
+}
+
+EosDetectorType EosDetector::append(int tokenId, const char* piece) {
+    int pieceLength = strlen(piece);
+    int length = bufferPos + pieceLength + 1;
+    if (length > bufferSize) {
+        char* newBuffer = new char[length];
+        if (bufferPos > 0)
+            memcpy(newBuffer, buffer, bufferPos);
+        if (bufferSize > 0)
+            delete[] buffer;
+        buffer = newBuffer;
+        bufferSize = length; // track the new capacity, otherwise the old buffer leaks
+    }
+    memcpy(buffer + bufferPos, piece, pieceLength + 1);
+    bufferPos += pieceLength;
+
+    // detection
+
+    if (tokenId == eosId) {
+        eosPos = bufferPos - pieceLength;
+        return EOS;
+    }
+    eosPos = -1;
+
+    for (size_t s = 0; s < nStops; s++) {
+        size_t stopSize = stopSizes[s];
+        if (bufferPos > stopSize + paddingLeft + paddingRight) continue;
+
+        for (int lo = 0; lo <= paddingLeft; lo++) {
+            int n = bufferPos - lo;
+            if (n == 0 || n > stopSize + paddingRight) continue;
+            if (n > stopSize) n = stopSize;
+            if (strncmp(buffer + lo, stops[s], n) == 0) {
+                if (n == stopSize) {
+                    eosPos = lo;
+                    return EOS;
+                }
+                return MAYBE_EOS;
+            }
+        }
+    }
+    return NOT_EOS;
+}
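+
+// Matching strategy: appended pieces accumulate in `buffer` and are compared
+// against every stop string at each offset up to paddingLeft. A full match
+// (or receiving eosId itself) reports EOS and records where the stop begins,
+// so getDelta() can expose only the text in front of it; a prefix match
+// reports MAYBE_EOS so the caller keeps buffering instead of printing.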
+
+char* EosDetector::getDelta() {
+    if (eosPos == -1) return buffer;
+    if (eosPos == 0) return NULL;
+    buffer[eosPos] = '\0';
+    return buffer;
+}
+
+void EosDetector::clear() {
+    bufferPos = 0;
+}
diff --git a/src/tokenizer.hpp b/src/tokenizer.hpp
index 879480e..5c73ec1 100644
--- a/src/tokenizer.hpp
+++ b/src/tokenizer.hpp
@@ -6,15 +6,13 @@
 
 bool isSafePiece(char *piece);
 void safePrintf(char *piece);
-void readStdin(const char* guide, char* buffer, size_t bufsize);
 
 typedef struct {
     char *str;
     int id;
 } TokenIndex;
 
-struct TokenizerHeader {
-    unsigned int magic;
+struct TokenizerOldHeader {
     unsigned int vocabSize;
     unsigned int maxTokenLength;
     int bosId;
@@ -22,18 +20,32 @@ struct TokenizerHeader {
     int padId;
 };
 
+enum TokenizerHeaderKey {
+    TOK_VERSION = 0,
+    TOK_VOCAB_SIZE = 1,
+    MAX_TOKEN_LENGTH = 2,
+    BOS_ID = 3,
+    EOS_ID = 4,
+    PAD_ID = 5,
+    CHAT_EOS_ID = 6,
+    CHAT_TEMPLATE = 7,
+};
+
 class Tokenizer {
 private:
     unsigned int maxTokenLength;
-    char** vocab;
     float* vocabScores;
     TokenIndex *sortedVocab;
     int vocabSize;
     unsigned char bytePieces[512]; // stores all single-byte strings
 
 public:
+    char** vocab;
     int bosId;
     int eosId;
+    int chatEosId;
+    int nChatTemplates;
+    char** chatTemplate;
 
     Tokenizer(char* tokenizer_path, int vocab_size);
     ~Tokenizer();
@@ -65,4 +77,40 @@ class Sampler {
     void setSeed(unsigned long long rngSeed);
 };
 
+class TokenizerStops {
+public:
+    const char** stops;
+    size_t nStops;
+    size_t maxStopLength;
+    TokenizerStops(Tokenizer* tokenizer);
+    ~TokenizerStops();
+};
+
+enum EosDetectorType {
+    MAYBE_EOS = 0,
+    EOS = 1,
+    NOT_EOS = 2,
+};
+
+class EosDetector {
+private:
+    int eosId;
+    size_t nStops;
+    const char** stops;
+    size_t* stopSizes;
+    size_t bufferPos;
+    size_t bufferSize;
+    int eosPos;
+    int paddingLeft;
+    int paddingRight;
+public:
+    char* buffer;
+    EosDetector(int eosId, size_t nStops, const char** stops, int paddingLeft, int paddingRight);
+    ~EosDetector();
+
+    EosDetectorType append(int tokenId, const char* piece);
+    char* getDelta();
+    void clear();
+};
+
 #endif