Skip to content

Commit deda6dc

Browse files
Merge pull request #1 from CodeWithKyrian/apply-chat-template
Add ability to use chat templates in Text Generation
2 parents 1fa19ea + a815a89 commit deda6dc

File tree

6 files changed

+62
-14
lines changed

6 files changed

+62
-14
lines changed

composer.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
"ankane/onnxruntime": "^0.2.0",
1919
"ext-gd": "*",
2020
"rindow/rindow-math-matrix": "^1.2",
21-
"spatie/fork": "^1.2"
21+
"spatie/fork": "^1.2",
22+
"codewithkyrian/jinja-php": "^1.0"
2223
},
2324
"require-dev": {
2425
"pestphp/pest": "^2.31",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
use Codewithkyrian\Transformers\PretrainedTokenizers\AutoTokenizer;
6+
7+
require_once './bootstrap.php';
8+
9+
$tokenizer = AutoTokenizer::fromPretrained('mistralai/Mistral-7B-Instruct-v0.1');
10+
//$tokenizer = AutoTokenizer::fromPretrained('facebook/blenderbot-400M-distill');
11+
$messages = [
12+
['role' => 'user', 'content' => 'Hello!'],
13+
['role' => 'assistant', 'content' => 'Hi! How are you?'],
14+
['role' => 'user', 'content' => 'I am doing great.'],
15+
['role' => 'assistant', 'content' => 'That is great to hear.'],
16+
];
17+
18+
$text = $tokenizer->applyChatTemplate($messages, addGenerationPrompt: true, tokenize: false);
19+
20+
dd($text);
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
6+
namespace Codewithkyrian\Transformers\Exceptions;
7+
8+
class TemplateParseException extends \Exception implements TransformersException
9+
{
10+
public static function undefinedVariable($variableName): TemplateParseException
11+
{
12+
return new self("Undefined variable: $variableName");
13+
}
14+
}

src/Pipelines/TextGenerationPipeline.php

+15
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,14 @@ public function __invoke(array|string $texts, ...$args): array
6666

6767
$generationConfig = new GenerationConfig($snakeCasedArgs);
6868

69+
$isChatMode = $this->isChatMode($texts);
70+
71+
if ($isChatMode) {
72+
$texts = $this->tokenizer->applyChatTemplate($texts, addGenerationPrompt: true, tokenize: false);
73+
}
74+
6975
$isBatched = is_array($texts);
76+
7077
if (!$isBatched) {
7178
$texts = [$texts];
7279
}
@@ -86,6 +93,7 @@ public function __invoke(array|string $texts, ...$args): array
8693

8794
$decoded = $this->tokenizer->batchDecode($outputTokenIds, skipSpecialTokens: true);
8895

96+
8997
$toReturn = array_fill(0, count($texts), []);
9098

9199
for ($i = 0; $i < count($decoded); ++$i) {
@@ -103,4 +111,11 @@ protected function camelCaseToSnakeCase(string $input): string
103111
{
104112
return strtolower(preg_replace('/(?<!^)[A-Z]/', '_$0', $input));
105113
}
114+
115+
// Detect chat mode
116+
protected function isChatMode(string|array $texts): bool
117+
{
118+
return is_array($texts) && isset($texts[0]) && is_array($texts[0]) && !array_is_list($texts[0]);
119+
120+
}
106121
}

src/PretrainedTokenizers/PretrainedTokenizer.php

+11-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
namespace Codewithkyrian\Transformers\PretrainedTokenizers;
77

8+
use Codewithkyrian\Jinja\Template;
89
use Codewithkyrian\Transformers\Decoders\Decoder;
910
use Codewithkyrian\Transformers\Normalizers\Normalizer;
1011
use Codewithkyrian\Transformers\PostProcessors\PostProcessedOutput;
@@ -18,7 +19,7 @@ class PretrainedTokenizer
1819
{
1920
protected bool $returnTokenTypeIds = false;
2021

21-
protected bool $warnedAboutChatTemplate;
22+
protected bool $warnedAboutChatTemplate = false;
2223

2324
protected string $defaultChatTemplate = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}";
2425

@@ -68,7 +69,7 @@ class PretrainedTokenizer
6869
protected bool $legacy;
6970

7071
protected mixed $chatTemplate;
71-
protected \SplObjectStorage $compiledTemplateCache;
72+
protected array $compiledTemplateCache = [];
7273

7374
/**
7475
* @param array $tokenizerJSON The JSON of the tokenizer.
@@ -145,8 +146,6 @@ public function __construct(protected array $tokenizerJSON, protected array $tok
145146
$this->legacy = false;
146147

147148
$this->chatTemplate = $tokenizerConfig['chat_template'] ?? null;
148-
$this->compiledTemplateCache = new \SplObjectStorage();
149-
150149
}
151150

152151
/**
@@ -335,7 +334,7 @@ function ($key) {
335334
if (
336335
array_reduce($encodedTokens, function ($carry, $x) use ($encodedTokens) {
337336
foreach ($x as $key => $value) {
338-
if (count($value) !== count($encodedTokens[0][$key] ?? [])) {
337+
if (count($value ?? []) !== count($encodedTokens[0][$key] ?? [])) {
339338
return true;
340339
}
341340
}
@@ -610,10 +609,10 @@ function truncateHelper(array &$item, int $length): void
610609

611610
protected function getDefaultChatTemplate(): string
612611
{
613-
if (!$this->warnedAboutChatTemplate) {
614-
trigger_error("The default chat template is deprecated and will be removed in a future version. Please use the `chat_template` option instead.", E_USER_WARNING);
615-
$this->warnedAboutChatTemplate = true;
616-
}
612+
// if (!$this->warnedAboutChatTemplate) {
613+
// trigger_error("The default chat template is deprecated and will be removed in a future version. Please use the `chat_template` option instead.", E_USER_WARNING);
614+
// $this->warnedAboutChatTemplate = true;
615+
// }
617616

618617
return $this->defaultChatTemplate;
619618
}
@@ -670,7 +669,7 @@ public function applyChatTemplate(
670669

671670
if ($compiledTemplate === null) {
672671
// TODO: Use Jinja to compile the template
673-
$compiledTemplate = null;
672+
$compiledTemplate = new Template($chatTemplate);
674673
$this->compiledTemplateCache[$chatTemplate] = $compiledTemplate;
675674
}
676675

@@ -695,10 +694,10 @@ public function applyChatTemplate(
695694
addSpecialTokens: false,
696695
truncation: $truncation,
697696
maxLength: $maxLength
698-
)['input_ids'];
697+
)['input_ids']->toArray();
699698
}
700699

701-
return $rendered;
700+
return stripcslashes($rendered);
702701
}
703702

704703
}

src/Utils/Helpers.php

-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ function array_every(array $array, callable $callback): bool
5858
return true;
5959
}
6060

61-
6261
function joinPaths(string ...$args): string
6362
{
6463
$paths = [];

0 commit comments

Comments
 (0)