Skip to content

Commit 715e2f7

Browse files
Refactor arguments and fix wrong attention mask type in generation
1 parent 00be346 commit 715e2f7

File tree

4 files changed

+52
-21
lines changed

4 files changed

+52
-21
lines changed

examples/pipelines/text-generation.php

+23-17
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,34 @@
1010
use function Codewithkyrian\Transformers\Utils\timeUsage;
1111

1212
ini_set('memory_limit', -1);
13-
14-
//$generator = pipeline('text-generation', 'Xenova/gpt2');
15-
//
16-
//$streamer = StdOutStreamer::make($generator->tokenizer);
1713
//
18-
//$output = $generator('The Black man worked as a',
19-
// streamer: $streamer,
20-
// maxNewTokens: 128,
21-
// doSample: true,
22-
// temperature: 0.7,
23-
// repetitionPenalty: 1.3,
24-
// earlyStopping: true
25-
//);
14+
$generator = pipeline('text-generation', 'Xenova/gpt2');
2615

27-
$generator = pipeline('text-generation', 'Xenova/codegen-350M-mono');
2816
$streamer = StdOutStreamer::make($generator->tokenizer);
2917

30-
$output = $generator(
31-
'def fib(n):',
18+
$messages = [
19+
['role' => 'user', 'content' => 'Hello!'],
20+
['role' => 'assistant', 'content' => 'Hi! How are you?'],
21+
['role' => 'user', 'content' => 'I am doing great. What about you?'],
22+
];
23+
24+
$output = $generator("I love going to school but I don't",
3225
streamer: $streamer,
33-
maxNewTokens: 50,
34-
doSample: true
26+
maxNewTokens: 128,
27+
doSample: true,
28+
temperature: 0.7,
29+
repetitionPenalty: 1.3,
30+
earlyStopping: true
3531
);
3632

33+
//$generator = pipeline('text-generation', 'Xenova/codegen-350M-mono');
34+
//$streamer = StdOutStreamer::make($generator->tokenizer);
35+
//
36+
//$output = $generator(
37+
// 'def fib(n):',
38+
// streamer: $streamer,
39+
// maxNewTokens: 100,
40+
// doSample: true
41+
//);
42+
3743
dd("done", timeUsage(), memoryUsage());

src/Models/Pretrained/PreTrainedModel.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ public function generate(
684684
Tensor $inputs,
685685
?GenerationConfig $generationConfig = null,
686686
?LogitsProcessorList $logitsProcessor = null,
687-
array $inputsAttentionMask = null,
687+
Tensor $inputsAttentionMask = null,
688688
?Streamer $streamer = null,
689689
): array
690690
{

src/Pipelines/TextGenerationPipeline.php

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public function __invoke(array|string $texts, ...$args): array
6464
$snakeCasedArgs[$this->camelCaseToSnakeCase($key)] = $value;
6565
}
6666

67-
$generateKwargs = new GenerationConfig($snakeCasedArgs);
67+
$generationConfig = new GenerationConfig($snakeCasedArgs);
6868

6969
$isBatched = is_array($texts);
7070
if (!$isBatched) {
@@ -75,14 +75,14 @@ public function __invoke(array|string $texts, ...$args): array
7575
$addSpecialTokens = $this->model->config['add_special_tokens'] ?? false;
7676

7777
$this->tokenizer->paddingSide = 'left';
78-
['input_ids' => $inputIds, 'attention_mask' => $attentionMask] = $this->tokenizer->__invoke(
78+
['input_ids' => $inputIds, 'attention_mask' => $attentionMask] = $this->tokenizer->tokenize(
7979
$texts,
8080
padding: true,
8181
addSpecialTokens: $addSpecialTokens,
8282
truncation: true
8383
);
8484

85-
$outputTokenIds = $this->model->generate($inputIds, generationConfig: $generateKwargs, streamer: $streamer);
85+
$outputTokenIds = $this->model->generate($inputIds, generationConfig: $generationConfig, streamer: $streamer);
8686

8787
$decoded = $this->tokenizer->batchDecode($outputTokenIds, skipSpecialTokens: true);
8888

src/PretrainedTokenizers/PretrainedTokenizer.php

+25
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,31 @@ function ($key) {
368368
return $result;
369369
}
370370

371+
/**
372+
* Tokenize the given text(s) by forwarding all arguments, in declaration order, to __invoke().
373+
*
374+
* @param string|array $text The text to tokenize.
375+
* @param string|array|null $textPair Optional second sequence to be encoded. If set, must be the same type as text.
376+
* @param bool|string $padding Whether to pad the input sequences.
377+
* @param bool $addSpecialTokens Whether to add the special tokens associated with the corresponding model.
378+
* @param bool $truncation Whether to truncate the input sequences.
379+
* @param int|null $maxLength Maximum length of the returned list and optionally padding length.
380+
*
381+
* @return array{input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor|null}
382+
*/
383+
public function tokenize(
384+
string|array $text,
385+
string|array|null $textPair = null,
386+
bool|string $padding = false,
387+
bool $addSpecialTokens = true,
388+
bool $truncation = false,
389+
?int $maxLength = null,
390+
): array
391+
{
392+
return $this->__invoke($text, $textPair, $padding, $addSpecialTokens, $truncation, $maxLength);
393+
}
394+
395+
371396
/**
372397
* Encodes a single text using the preprocessor pipeline of the tokenizer.
373398
*

0 commit comments

Comments
 (0)