Skip to content

Commit

Permalink
fix: fuse function not combining unknown token ids correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeWithKyrian committed Sep 14, 2024
1 parent bee47e0 commit 3008013
Showing 1 changed file with 27 additions and 34 deletions.
61 changes: 27 additions & 34 deletions src/Tokenizers/Tokenizer.php
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
<?php /** @noinspection PhpUnreachableStatementInspection */
<?php

declare(strict_types=1);


namespace Codewithkyrian\Transformers\Tokenizers;

use ArrayObject;
Expand Down Expand Up @@ -44,12 +43,19 @@ abstract class Tokenizer
*/
protected ?string $unkToken = null;

/**
* Whether to fuse the unknown token into the vocabulary.
*/
protected bool $fuseUnk = false;

public function __construct(protected array $config)
{
$this->continuingSubwordPrefix = $config['continuing_subword_prefix'] ?? null;
if ($this->continuingSubwordPrefix == "") {
$this->continuingSubwordPrefix = null;
}

$this->fuseUnk = $config['fuse_unk'] ?? false;
}

/**
Expand Down Expand Up @@ -88,8 +94,7 @@ public static function load(
string $revision,
mixed $legacy,
?callable $onProgress = null
): array
{
): array {
$tokenizerJson = Hub::getJson(
$modelNameOrPath,
fileName: 'tokenizer.json',
Expand Down Expand Up @@ -227,7 +232,7 @@ public function __invoke(array $tokens): array
{
$ids = $this->encode($tokens);

if ($this->fuseUnk()) {
if ($this->fuseUnk) {
$ids = $this->fuse($ids, $this->unkTokenId, $this->tokenToIds);
}

Expand All @@ -242,43 +247,32 @@ public function __invoke(array $tokens): array
*/
protected abstract function encode(array $tokens): array;

protected function fuseUnk(): bool
{
return $this->config['fuse_unk'] ?? false;
}

/**
* Helper function to fuse consecutive values in an array equal to the specified value.
*
* @param string[] $arr The input array
* @param array $arr The input array.
* @param mixed $value The value to fuse on.
* @param array<string, mixed> $mappings The mapping from input domain to value.
* @return array
* @param array $mapping The mapping from input domain to value.
* @return array The fused array.
*/
private static function fuse(array $arr, mixed $value, array $mappings): array
{
protected function fuse(array $arr, mixed $value, array $mapping): array {
$fused = [];
$fusedIds = [];
$fusedLength = 0;
$i = 0;
$length = count($arr);

foreach ($arr as $i => $v) {
if ($v === $value) {
$fusedLength++;
} else {
if ($fusedLength > 0) {
$fused[] = $mappings[$value];
$fusedIds[] = $value;
$fusedLength = 0;
}
while ($i < $length) {
$fused[] = $arr[$i];

$fused[] = $v;
$fusedIds[] = $i;
// Check if the current element's mapping is not equal to the specified value
if (($mapping[$arr[$i]] ?? $value) !== $value) {
$i++;
continue;
}
}

if ($fusedLength > 0) {
$fused[] = $mappings[$value];
$fusedIds[] = $value;
// Skip consecutive elements equal to the specified value
while ($i < $length && ($mapping[$arr[$i]] ?? $value) === $value) {
$i++;
}
}

return $fused;
Expand Down Expand Up @@ -343,7 +337,6 @@ public function convertTokensToIds(array $tokens): array
$ids[] = $this->tokenToIds[$token] ?? $this->unkTokenId;
}


return $ids;
}

Expand All @@ -363,4 +356,4 @@ public function convertIdsToTokens(array $ids): array

return $tokens;
}
}
}

0 comments on commit 3008013

Please sign in to comment.