Skip to content

Commit 89ab0f1

Browse files
feat: Add comprehensive tests for tokenizers and new pretrained tokenizers (#74)
1 parent 134c002 commit 89ab0f1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+30093
-575
lines changed

examples/bootstrap.php

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77

88
require_once './vendor/autoload.php';
99

10-
Transformers::setup()->setImageDriver(ImageDriver::VIPS);
11-
10+
Transformers::setup()
11+
->setCacheDir('/Users/Kyrian/.transformers')
12+
->setImageDriver(ImageDriver::VIPS);

examples/pipelines/asr.php

+7-4
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,17 @@
2222
//$audioUrl = __DIR__ . '/../sounds/taunt.wav';
2323
//$audioUrl = __DIR__ . '/../sounds/gettysburg.wav';
2424
//$audioUrl = __DIR__ . '/../sounds/kyrian-speaking.wav';
25-
//$audioUrl = __DIR__ . '/../sounds/ted_60.wav';
26-
$audioUrl = __DIR__ . '/../sounds/sample-1.mp3';
25+
$audioUrl = __DIR__ . '/../sounds/ted_60.wav';
26+
//$audioUrl = __DIR__ . '/../sounds/sample-1.mp3';
2727

28+
$streamer = WhisperTextStreamer::make()
29+
//->onTimestampStart(fn($timestamp) => dump($timestamp));
30+
->onStream(fn($text) => print($text));
2831

2932
$output = $transcriber($audioUrl,
3033
maxNewTokens: 256,
3134
chunkLengthSecs: 24,
32-
// returnTimestamps: 'word',
35+
streamer: $streamer,
3336
);
3437

35-
dd($output, timeUsage(), memoryUsage());
38+
dd($output, timeUsage(), memoryUsage());

examples/pipelines/text-classification.php

+3-7
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,13 @@
77
require_once './bootstrap.php';
88

99

10-
//$classifier = pipeline('text-classification', 'Xenova/toxic-bert');
10+
$classifier = pipeline('text-classification', 'Xenova/toxic-bert');
1111
//
1212
//$result = $classifier("I hate you! You gave me life but in misery", topK: -1);
1313

1414

15-
$classifier = pipeline('text-classification', 'Xenova/distilbert-base-uncased-mnli');
16-
17-
$result = $classifier('I love you!, You frustrated my life');
15+
// $classifier = pipeline('text-classification', 'Xenova/distilbert-base-uncased-mnli');
1816

17+
$result = $classifier('I want to beat him to pulp', topK: -1);
1918

2019
dd($result);
21-
22-
23-

examples/pipelines/text-generation.php

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
//$generator = pipeline('text-generation', 'Xenova/gpt2');
1616
//$generator = pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
17-
$generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');
17+
//$generator = pipeline('text-generation', 'Xenova/TinyLlama-1.1B-Chat-v1.0');
18+
$generator = pipeline('text-generation', 'onnx-community/Llama-3.2-1B-Instruct', modelFilename: 'model_q4');
1819

1920
$streamer = TextStreamer::make()->shouldSkipPrompt();
2021

examples/pipelines/text2text-generation.php

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
ini_set('memory_limit', -1);
1111

12-
$generator = pipeline('text2text-generation', 'Xenova/LaMini-Flan-T5-783M');
13-
//$generator = pipeline('text2text-generation', 'Xenova/flan-t5-small', quantized: true);
12+
//$generator = pipeline('text2text-generation', 'Xenova/LaMini-Flan-T5-783M');
13+
$generator = pipeline('text2text-generation', 'Xenova/flan-t5-small', quantized: true);
1414

1515
$streamer = TextStreamer::make();
1616

@@ -22,4 +22,4 @@
2222
$output = $generator($query, streamer: $streamer, maxNewTokens: 256, doSample: true, repetitionPenalty: 1.1, temperature: 0.7);
2323

2424
//dd($output);
25-
dd('Done', timeUsage(), memoryUsage());
25+
dd('Done', timeUsage(), memoryUsage());

examples/pipelines/token-classification.php

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
ini_set('memory_limit', -1);
1313

14-
//$classifier = pipeline('token-classification', 'Xenova/bert-base-NER');
15-
$classifier = pipeline('token-classification', 'codewithkyrian/bert-english-uncased-finetuned-pos');
14+
$classifier = pipeline('token-classification', 'Xenova/bert-base-NER');
15+
//$classifier = pipeline('token-classification', 'codewithkyrian/bert-english-uncased-finetuned-pos');
1616

1717
$output = $classifier(
1818
'My name is Kyrian and I live in Nigeria',

src/Decoders/ByteFallback.php

+19-8
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,42 @@ public function __construct(array $config)
1818

1919
protected function decodeChain(array $tokens): array
2020
{
21-
$newTokens = [];
2221
$previousByteTokens = [];
22+
$newTokens = [];
2323

2424
foreach ($tokens as $token) {
2525
$bytes = null;
26+
27+
// Check if the token is of the form <0xXX>
2628
if (strlen($token) === 6 && str_starts_with($token, '<0x') && str_ends_with($token, '>')) {
29+
// Extract the hexadecimal value from the token
2730
$byte = hexdec(substr($token, 3, 2));
2831
if (!is_nan($byte)) {
2932
$bytes = $byte;
3033
}
3134
}
35+
3236
if ($bytes !== null) {
37+
// Add byte to previousByteTokens
3338
$previousByteTokens[] = $bytes;
3439
} else {
35-
if (count($previousByteTokens) > 0) {
36-
$string = $this->bytesToString($previousByteTokens);
37-
$newTokens[] = $string;
38-
$previousByteTokens = [];
40+
// If we have accumulated byte tokens, decode them to a string
41+
if (!empty($previousByteTokens)) {
42+
$string = pack('C*', ...$previousByteTokens); // Convert bytes back to string
43+
$newTokens[] = $string; // Add decoded string to newTokens
44+
$previousByteTokens = []; // Reset byte accumulator
3945
}
46+
// Add the non-byte token to newTokens
4047
$newTokens[] = $token;
4148
}
4249
}
43-
if (count($previousByteTokens) > 0) {
44-
$string = $this->bytesToString($previousByteTokens);
50+
51+
52+
// After the loop, if there are still byte tokens, decode them
53+
if (!empty($previousByteTokens)) {
54+
$string = pack('C*', ...$previousByteTokens); // Convert remaining bytes to string
4555
$newTokens[] = $string;
56+
$previousByteTokens = []; // Reset byte accumulator
4657
}
4758

4859
return $newTokens;
@@ -59,4 +70,4 @@ protected function bytesToString(array $bytes): string
5970
$binaryString = pack('C*', ...$bytes);
6071
return mb_convert_encoding($binaryString, 'ISO-8859-1');
6172
}
62-
}
73+
}

src/Decoders/ByteLevelDecoder.php

+4-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
namespace Codewithkyrian\Transformers\Decoders;
77

88
use Codewithkyrian\Transformers\Tokenizers\AddedToken;
9-
use Codewithkyrian\Transformers\Tokenizers\Tokenizer;
9+
use Codewithkyrian\Transformers\Tokenizers\TokenizerModel;
1010
use SplFixedArray;
1111

1212
class ByteLevelDecoder extends Decoder
@@ -287,7 +287,7 @@ public function convertTokensToString(array $tokens): string
287287

288288
$binaryString = pack('C*', ...$byteArray);
289289

290-
return mb_convert_encoding($binaryString, 'ISO-8859-1');
290+
return mb_convert_encoding($binaryString, 'UTF-8');
291291
}
292292

293293
protected function decodeChain(array $tokens): array
@@ -298,9 +298,7 @@ protected function decodeChain(array $tokens): array
298298
foreach ($tokens as $token) {
299299
// No need to check skip_special_tokens since the tokens are already filtered
300300

301-
$addedToken = array_filter($this->addedTokens, function (AddedToken $x) use ($token) {
302-
return $x->content === $token;
303-
});
301+
$addedToken = array_filter($this->addedTokens, fn (AddedToken $x) => $x->content === $token);
304302

305303
if (!empty($addedToken)) {
306304
if (!empty($currentSubText)) {
@@ -319,7 +317,6 @@ protected function decodeChain(array $tokens): array
319317
}
320318

321319
// TODO: add spaces_between_special_tokens and clean_up_tokenization_spaces options
322-
323320
return $subTexts;
324321
}
325-
}
322+
}

src/Decoders/CTCDecoder.php

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Codewithkyrian\Transformers\Decoders;
77

8-
use Codewithkyrian\Transformers\Tokenizers\Tokenizer;
8+
use Codewithkyrian\Transformers\Tokenizers\TokenizerModel;
99

1010
/**
1111
* The CTC (Connectionist Temporal Classification) decoder.
@@ -65,7 +65,7 @@ function convertTokensToString(array $tokens): string
6565
$text = implode('', $filteredTokens);
6666
if ($this->cleanup) {
6767
// Cleanup and replace delimiter token
68-
$text = trim(str_replace($this->wordDelimiterToken, ' ', Tokenizer::cleanUpTokenization($text)));
68+
$text = trim(str_replace($this->wordDelimiterToken, ' ', TokenizerModel::cleanUpTokenization($text)));
6969
}
7070

7171
return $text;
@@ -75,4 +75,4 @@ protected function decodeChain(array $tokens): array
7575
{
7676
return [$this->convertTokensToString($tokens)];
7777
}
78-
}
78+
}

src/Decoders/ReplaceDecoder.php

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Decoders;
76

87
class ReplaceDecoder extends Decoder
@@ -38,4 +37,4 @@ protected function decodeChain(array $tokens): array
3837
return $token;
3938
}, $tokens);
4039
}
41-
}
40+
}

src/Decoders/StripDecoder.php

+6-5
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,18 @@ protected function decodeChain(array $tokens): array
2828
return array_map(function ($token) {
2929
$startCut = 0;
3030
for ($i = 0; $i < $this->start; ++$i) {
31-
if ($token[$i] ?? null === $this->content) {
31+
$char = mb_substr($token, $i, 1);
32+
if ($char === $this->content) {
3233
$startCut = $i + 1;
3334
continue;
3435
} else {
3536
break;
3637
}
3738
}
3839

39-
$stopCut = strlen($token);
40+
$stopCut = mb_strlen($token);
4041
for ($i = 0; $i < $this->stop; ++$i) {
41-
$index = strlen($token) - $i - 1;
42+
$index = mb_strlen($token) - $i - 1;
4243
if ($token[$index] ?? null === $this->content) {
4344
$stopCut = $index;
4445
continue;
@@ -47,7 +48,7 @@ protected function decodeChain(array $tokens): array
4748
}
4849
}
4950

50-
return substr($token, $startCut, $stopCut - $startCut);
51+
return mb_substr($token, $startCut, $stopCut - $startCut);
5152
}, $tokens);
5253
}
53-
}
54+
}

src/Decoders/VitsDecoder.php

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\Decoders;
6+
7+
class VitsDecoder extends Decoder
8+
{
9+
10+
protected function decodeChain(array $tokens): array
11+
{
12+
$decoded = '';
13+
14+
for ($i = 1; $i < count($tokens); $i += 2) {
15+
$decoded .= $tokens[$i];
16+
}
17+
18+
return [$decoded];
19+
}
20+
}

src/Decoders/WordPieceDecoder.php

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Codewithkyrian\Transformers\Decoders;
77

8-
use Codewithkyrian\Transformers\Tokenizers\Tokenizer;
8+
use Codewithkyrian\Transformers\Tokenizers\TokenizerModel;
99

1010
class WordPieceDecoder extends Decoder
1111
{
@@ -31,12 +31,12 @@ protected function decodeChain(array $tokens): array
3131
}
3232
}
3333
if ($this->cleanup) {
34-
$token = Tokenizer::cleanUpTokenization($token);
34+
$token = TokenizerModel::cleanUpTokenization($token);
3535
}
3636

3737
$decodedTokens[] = $token;
3838
}
3939

4040
return $decodedTokens;
4141
}
42-
}
42+
}

src/Generation/Samplers/Sampler.php

+1-11
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,12 @@ abstract public function sample(Tensor $logits, int $index);
4646
*/
4747
public function getLogits(Tensor $logits, int $index): Tensor
4848
{
49-
// $vocabSize = $logits->shape()[$logits->ndim() - 1];
50-
51-
// $start = array_fill(0, $logits->ndim(), 0);
52-
// $size = array_fill(0, $logits->ndim(), 1);
53-
//
54-
// array_splice($start, -2, replacement: [$index, 0]);
55-
// array_splice($size, -2, replacement: [1, $vocabSize]);
56-
//
57-
// $logs = $logits->sliceWithBounds($start, $size);
5849
$logits = $logits->slice($index);
5950

6051
if ($this->generationConfig->temperature > 0) {
6152
$logits = $logits->multiply(1 / $this->generationConfig->temperature);
6253
}
6354

64-
// Remove all dimensions of 1, leaving a flat 1D array of vocab_size
6555
return $logits->squeeze();
6656
}
6757

@@ -116,4 +106,4 @@ public static function getSampler(GenerationConfig $generationConfig): Sampler
116106
return new GreedySampler($generationConfig);
117107
}
118108
}
119-
}
109+
}

src/Generation/Streamers/TextStreamer.php

+7-4
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,14 @@ public function put(mixed $value): void
4141

4242
if ($this->skipPrompt && $this->nextTokensArePrompt) {
4343
$this->nextTokensArePrompt = false;
44-
$this->printedText = $this->tokenizer->decode($this->promptTokens, skipSpecialTokens: true);
45-
$this->printedLength = mb_strlen($this->printedText);
44+
// $this->printedText = $this->tokenizer->decode($this->promptTokens, skipSpecialTokens: true);
45+
// $this->printedLength = mb_strlen($this->printedText);
46+
// $this->lastDecodedCheckpointForToken = count($this->promptTokens) - 1;
47+
// $this->lastDecodedCheckpointForText = mb_strlen($this->printedText);
48+
// return;
49+
$prompt = $this->tokenizer->decode($this->promptTokens, skipSpecialTokens: true);
50+
$this->printedLength = mb_strlen($prompt);
4651
$this->lastDecodedCheckpointForToken = count($this->promptTokens) - 1;
47-
$this->lastDecodedCheckpointForText = mb_strlen($this->printedText);
4852
return;
4953
}
5054

@@ -90,4 +94,3 @@ public function end(): void
9094
$this->lastDecodedCheckpointForText = 0;
9195
}
9296
}
93-

0 commit comments

Comments
 (0)