-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add to tokenizer chat configuration. (#76)
- Loading branch information
Showing
11 changed files
with
712 additions
and
191 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import sys | ||
import json | ||
import os | ||
writer = __import__('tokenizer-writer') | ||
|
||
def openJson(path):
    """Read the UTF-8 encoded JSON file at `path` and return the parsed value.

    Errors from `open` (e.g. a missing file) and from `json.load` (malformed
    JSON) are not caught here and propagate to the caller.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
|
||
def printUsage():
    """Print the command-line usage of the converter script to stdout."""
    print('Usage: python convert-tokenizer-hf.py <tokenizerFolderPath> <name>')
    print()
    print('Options:')
    print(' <tokenizerFolderPath> The path to the folder with tokenizer.json and tokenizer_config.json')
    print(' <name> The name of the tokenizer (e.g. "llama3")')
|
||
if __name__ == '__main__':
    # Two positional arguments are required. sys.argv[0] is the script itself,
    # so anything shorter than three entries is an incomplete command line.
    if len(sys.argv) < 3:
        printUsage()
        sys.exit(1)

    dirPath = sys.argv[1]
    name = sys.argv[2]
    tokenizerConfig = openJson(os.path.join(dirPath, 'tokenizer_config.json'))
    tokenizer = openJson(os.path.join(dirPath, 'tokenizer.json'))

    # Only this tokenizer class / model type combination is handled below.
    # Raised explicitly (not asserted) so the checks survive `python -O`.
    if tokenizerConfig['tokenizer_class'] != 'PreTrainedTokenizerFast':
        raise ValueError(
            f'Unsupported tokenizer_class: {tokenizerConfig["tokenizer_class"]!r}')
    if tokenizer['model']['type'] != 'BPE':
        raise ValueError(
            f'Unsupported tokenizer model type: {tokenizer["model"]["type"]!r}')

    tokens = []
    scores = []
    bosId = None
    eosId = None

    # The vocab must list ids 0..N-1 in order, because a token's position in
    # `tokens` is what identifies it in the output file.
    for expectedId, (token, tokenId) in enumerate(tokenizer['model']['vocab'].items()):
        if tokenId != expectedId:
            raise ValueError(
                f'Vocab token {token!r} has id {tokenId}, expected {expectedId}')
        tokens.append(token.encode('utf8'))
        scores.append(-float(expectedId))

    # Added tokens must continue the id sequence directly after the vocab.
    # BOS/EOS ids are looked up among the added tokens only.
    bosToken = tokenizerConfig.get('bos_token')
    eosToken = tokenizerConfig.get('eos_token')
    for addedToken in tokenizer.get('added_tokens') or []:
        expectedId = len(tokens)
        if addedToken['id'] != expectedId:
            raise ValueError(
                f'Added token {addedToken["content"]!r} has id {addedToken["id"]}, '
                f'expected {expectedId}')
        content = addedToken['content']
        tokens.append(content.encode('utf8'))
        scores.append(-float(expectedId))
        if content == bosToken:
            bosId = expectedId
        if content == eosToken:
            eosId = expectedId

    # Fail before prompting the user and before creating the output file;
    # otherwise the writer would reject the ids after the file already exists.
    if bosId is None:
        raise ValueError('Could not find the bos_token among the added tokens')
    if eosId is None:
        raise ValueError('Could not find the eos_token among the added tokens')

    templateChat = None
    if 'chat_template' in tokenizerConfig:
        template = tokenizerConfig['chat_template']
        print('⭐ Found chat template:')
        print()
        print(template.replace('\n', '\\n'))
        print()
        print('⭐ To create the tokenizer file you need to manually specify chat template values. Enter \\n for new line.')
        templateChat = {}
        templateKeys = ['chat_message_start', 'chat_role_start', 'chat_role_end', 'chat_message_end', 'chat_generation_prompt', 'chat_extra_stop']
        for key in templateKeys:
            value = input(f'⏩ Enter value for chat template key "{key}":\n')
            # The user types a literal backslash-n; store a real newline.
            templateChat[key] = value.replace('\\n', '\n')

    outputFileName = f'dllama_tokenizer_{name}.t'
    with open(outputFileName, 'wb') as outputFile:
        writer.writeTokenizer(outputFile, {
            'bos_id': bosId,
            'eos_id': eosId,
            'chat_eos_id': eosId,
        }, templateChat, tokens, scores)
    print(f'✅ Created {outputFileName}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import struct | ||
|
||
def writeTokenizer(file, params, chatTemplate, tokens, scores):
    """Serialize a tokenizer to the binary stream `file`.

    Layout written, in order (all integers/floats use the native `struct`
    byte order):
      1. magic number 0x567124 (int)
      2. total header size in bytes (int): magic + this field + key/value data
      3. one (key id, value) int pair for every recognised entry of `params`
         plus the computed `version`, `vocab_size`, `max_token_length` and,
         when a chat template is given, `chat_template` (its number of entries)
      4. if `chatTemplate` is non-empty: the UTF-8 byte length of every value
         (unsigned int each), followed by the non-empty values themselves
      5. for every token: score (float), byte length (unsigned int), bytes

    Args:
        file: binary writable object exposing `write(bytes)`.
        params: dict of header values; must contain non-None `bos_id` and
            `eos_id`. Keys not present in the header key table are reported
            on stdout and skipped. The dict is not modified.
        chatTemplate: dict of chat template strings, or a falsy value for none.
        tokens: sequence of non-empty `bytes` objects.
        scores: sequence of floats, one per token.

    Raises:
        AssertionError: if `bos_id`/`eos_id` is None or a token is empty.
        ValueError: if `tokens` is empty or `scores` is shorter than `tokens`.
    """
    assert params['eos_id'] is not None, 'eos_id must not be None'
    assert params['bos_id'] is not None, 'bos_id must not be None'

    nTokens = len(tokens)
    if nTokens == 0:
        raise ValueError('tokens must not be empty')
    if len(scores) < nTokens:
        # Reject up front instead of failing halfway through the token table.
        raise ValueError(
            f'scores has {len(scores)} entries but there are {nTokens} tokens')

    headerKeys = {
        'version': 0,
        'vocab_size': 1,
        'max_token_length': 2,
        'bos_id': 3,
        'eos_id': 4,
        'pad_id': 5,
        'chat_eos_id': 6,
        'chat_template': 7
    }

    # Work on a copy so the caller's dict is left untouched; insertion order
    # (caller keys first, computed keys after) determines the order on disk.
    headerParams = dict(params)
    headerParams['version'] = 0
    headerParams['vocab_size'] = nTokens
    headerParams['max_token_length'] = max(len(t) for t in tokens)
    if chatTemplate:
        headerParams['chat_template'] = len(chatTemplate)

    fields = []
    for key, value in headerParams.items():
        if key in headerKeys:
            fields.append(struct.pack('ii', headerKeys[key], value))
        else:
            print(f'Unknown header key: {key}')
    data = b''.join(fields)

    magic = struct.pack('i', 0x567124)
    # The size field counts the magic, the size field itself (same width as
    # the magic) and all key/value pairs.
    headerSize = struct.pack('i', len(magic) * 2 + len(data))
    file.write(magic + headerSize)
    file.write(data)

    print(headerParams)
    if chatTemplate:
        print(chatTemplate)

        # All lengths are written first, then the payloads back to back.
        encodedValues = [value.encode('utf8') for value in chatTemplate.values()]
        for encoded in encodedValues:
            file.write(struct.pack('I', len(encoded)))
        for encoded in encodedValues:
            if encoded:
                file.write(encoded)

    for token, score in zip(tokens, scores):
        size = len(token)
        assert size > 0, 'tokens must not contain empty entries'
        file.write(struct.pack('fI', score, size))
        file.write(token)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.