diff --git a/src/Tokenizers/Tokenizer.php b/src/Tokenizers/Tokenizer.php index 37093b7..c840320 100644 --- a/src/Tokenizers/Tokenizer.php +++ b/src/Tokenizers/Tokenizer.php @@ -1,8 +1,7 @@ -continuingSubwordPrefix = $config['continuing_subword_prefix'] ?? null; if ($this->continuingSubwordPrefix == "") { $this->continuingSubwordPrefix = null; } + + $this->fuseUnk = $config['fuse_unk'] ?? false; } /** @@ -88,8 +94,7 @@ public static function load( string $revision, mixed $legacy, ?callable $onProgress = null - ): array - { + ): array { $tokenizerJson = Hub::getJson( $modelNameOrPath, fileName: 'tokenizer.json', @@ -227,7 +232,7 @@ public function __invoke(array $tokens): array { $ids = $this->encode($tokens); - if ($this->fuseUnk()) { + if ($this->fuseUnk) { $ids = $this->fuse($ids, $this->unkTokenId, $this->tokenToIds); } @@ -242,43 +247,32 @@ public function __invoke(array $tokens): array */ protected abstract function encode(array $tokens): array; - protected function fuseUnk(): bool - { - return $this->config['fuse_unk'] ?? false; - } - /** * Helper function to fuse consecutive values in an array equal to the specified value. * - * @param string[] $arr The input array + * @param array $arr The input array. * @param mixed $value The value to fuse on. - * @param array $mappings The mapping from input domain to value. - * @return array + * @param array $mapping The mapping from input domain to value. + * @return array The fused array. */ - private static function fuse(array $arr, mixed $value, array $mappings): array - { + protected function fuse(array $arr, mixed $value, array $mapping): array { $fused = []; - $fusedIds = []; - $fusedLength = 0; + $i = 0; + $length = count($arr); - foreach ($arr as $i => $v) { - if ($v === $value) { - $fusedLength++; - } else { - if ($fusedLength > 0) { - $fused[] = $mappings[$value]; - $fusedIds[] = $value; - $fusedLength = 0; - } + while ($i < $length) { + $fused[] = $arr[$i]; - $fused[] = $v; - $fusedIds[] = $i; + // Check if the current element's mapping is not equal to the specified value + if (($mapping[$arr[$i]] ?? $value) !== $value) { + $i++; + continue; } - } - if ($fusedLength > 0) { - $fused[] = $mappings[$value]; - $fusedIds[] = $value; + // Skip consecutive elements equal to the specified value + while ($i < $length && ($mapping[$arr[$i]] ?? $value) === $value) { + $i++; + } } return $fused; @@ -343,7 +337,6 @@ public function convertTokensToIds(array $tokens): array $ids[] = $this->tokenToIds[$token] ?? $this->unkTokenId; } - return $ids; } @@ -363,4 +356,4 @@ public function convertIdsToTokens(array $ids): array return $tokens; } -} \ No newline at end of file +}