From 2cf18ccd83ad84ad8389a7f8c0198252d523f3f6 Mon Sep 17 00:00:00 2001 From: Kyrian Obikwelu Date: Fri, 27 Sep 2024 20:35:25 +0100 Subject: [PATCH] feat: Add support for PostProcessor Sequence --- src/PostProcessors/PostProcessor.php | 1 + src/PostProcessors/PostProcessorSequence.php | 68 ++++++++++++++++++++ src/Utils/AutoConfig.php | 2 +- src/Utils/GenerationConfig.php | 4 +- 4 files changed, 72 insertions(+), 3 deletions(-) create mode 100644 src/PostProcessors/PostProcessorSequence.php diff --git a/src/PostProcessors/PostProcessor.php b/src/PostProcessors/PostProcessor.php index 7c596c0..123ae26 100644 --- a/src/PostProcessors/PostProcessor.php +++ b/src/PostProcessors/PostProcessor.php @@ -28,6 +28,7 @@ public static function fromConfig(?array $config): ?self 'ByteLevel' => new ByteLevelPostProcessor($config), 'TemplateProcessing' => new TemplateProcessing($config), 'RobertaProcessing' => new RobertaProcessing($config), + 'Sequence' => new PostProcessorSequence($config), default => throw new \InvalidArgumentException("Unknown post-processor type {$config['type']}"), }; } diff --git a/src/PostProcessors/PostProcessorSequence.php b/src/PostProcessors/PostProcessorSequence.php new file mode 100644 index 0000000..45c48c6 --- /dev/null +++ b/src/PostProcessors/PostProcessorSequence.php @@ -0,0 +1,68 @@ +processors = array_map( + fn ($processorConfig) => PostProcessor::fromConfig($processorConfig), + $config['processors'] + ); + } + + /** + * Post-process the given tokens. + * + * @param array $tokens The list of tokens for the first sequence. + * @param string[]|null $tokenPair The input tokens for the second sequence in a pair. + * * @param bool $addSpecialTokens Whether to add the special tokens associated with the corresponding model. + * + * @return PostProcessedOutput An array containing the post-processed tokens and token_type_ids. + */ + public function postProcess(array $tokens, ?array $tokenPair = null, bool $addSpecialTokens = true): PostProcessedOutput + { + $tokenTypeIds = null; + + foreach ($this->processors as $processor) { + if ($processor instanceof ByteLevelPostProcessor) { + // Special case where we need to pass the tokens_pair to the post-processor + $output = $processor->postProcess($tokens); + $tokens = $output->tokens; + + if ($tokenPair !== null) { + $pairOutput = $processor->postProcess($tokenPair); + $tokenPair = $pairOutput->tokens; + } + } else { + $output = $processor->postProcess($tokens, $tokenPair, $addSpecialTokens); + $tokens = $output->tokens; + $tokenTypeIds = $output->tokenTypeIds; + } + } + + return new PostProcessedOutput($tokens, $tokenTypeIds); + } +} + +?> \ No newline at end of file diff --git a/src/Utils/AutoConfig.php b/src/Utils/AutoConfig.php index d346be3..6a5d869 100644 --- a/src/Utils/AutoConfig.php +++ b/src/Utils/AutoConfig.php @@ -16,7 +16,7 @@ class AutoConfig implements ArrayAccess protected array $architectures = []; - public int $padTokenId; + public int|array $padTokenId; protected int $vocabSize; diff --git a/src/Utils/GenerationConfig.php b/src/Utils/GenerationConfig.php index 2f7b7ac..0a23c99 100644 --- a/src/Utils/GenerationConfig.php +++ b/src/Utils/GenerationConfig.php @@ -124,8 +124,8 @@ class GenerationConfig implements \ArrayAccess /** @var bool Whether or not to return a `ModelOutput` instead of a plain tuple. */ public bool $return_dict_in_generate; - /** @var int|null The id of the *padding* token. */ - public ?int $pad_token_id; + /** @var int|int[]|null The id of the *padding* token. */ + public int|array|null $pad_token_id; /** @var int|null The id of the *beginning-of-sequence* token. */ public ?int $bos_token_id;