Skip to content

Commit

Permalink
feat: Add support for PostProcessor Sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeWithKyrian committed Sep 27, 2024
1 parent 0564dd1 commit 2cf18cc
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/PostProcessors/PostProcessor.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public static function fromConfig(?array $config): ?self
'ByteLevel' => new ByteLevelPostProcessor($config),
'TemplateProcessing' => new TemplateProcessing($config),
'RobertaProcessing' => new RobertaProcessing($config),
'Sequence' => new PostProcessorSequence($config),
default => throw new \InvalidArgumentException("Unknown post-processor type {$config['type']}"),
};
}
Expand Down
68 changes: 68 additions & 0 deletions src/PostProcessors/PostProcessorSequence.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
<?php

declare(strict_types=1);

namespace Codewithkyrian\Transformers\PostProcessors;

/**
* A post-processor that applies multiple post-processors in sequence.
*/
class PostProcessorSequence extends PostProcessor
{

/**
* List of post-processors to apply.
*/
protected array $processors;

/**
* Creates a new instance of PostProcessorSequence.
*
* @param array $config The configuration array.
* - 'processors' (array): The list of post-processors to apply.
*/
public function __construct(array $config)
{
parent::__construct($config);

$this->processors = array_map(
fn ($processorConfig) => PostProcessor::fromConfig($processorConfig),
$config['processors']
);
}

/**
* Post-process the given tokens.
*
* @param array $tokens The list of tokens for the first sequence.
* @param string[]|null $tokenPair The input tokens for the second sequence in a pair.
* * @param bool $addSpecialTokens Whether to add the special tokens associated with the corresponding model.
*
* @return PostProcessedOutput An array containing the post-processed tokens and token_type_ids.
*/
public function postProcess(array $tokens, ?array $tokenPair = null, bool $addSpecialTokens = true): PostProcessedOutput
{
$tokenTypeIds = null;

foreach ($this->processors as $processor) {
if ($processor instanceof ByteLevelPostProcessor) {
// Special case where we need to pass the tokens_pair to the post-processor
$output = $processor->postProcess($tokens);
$tokens = $output->tokens;

if ($tokenPair !== null) {
$pairOutput = $processor->postProcess($tokenPair);
$tokenPair = $pairOutput->tokens;
}
} else {
$output = $processor->postProcess($tokens, $tokenPair, $addSpecialTokens);
$tokens = $output->tokens;
$tokenTypeIds = $output->tokenTypeIds;
}
}

return new PostProcessedOutput($tokens, $tokenTypeIds);
}
}

?>
2 changes: 1 addition & 1 deletion src/Utils/AutoConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AutoConfig implements ArrayAccess

protected array $architectures = [];

public int $padTokenId;
public int|array $padTokenId;

protected int $vocabSize;

Expand Down
4 changes: 2 additions & 2 deletions src/Utils/GenerationConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ class GenerationConfig implements \ArrayAccess
/** @var bool Whether or not to return a `ModelOutput` instead of a plain tuple. */
public bool $return_dict_in_generate;

/** @var int|null The id of the *padding* token. */
public ?int $pad_token_id;
/** @var int|int[]|null The id of the *padding* token. */
public int|array|null $pad_token_id;

/** @var int|null The id of the *beginning-of-sequence* token. */
public ?int $bos_token_id;
Expand Down

0 comments on commit 2cf18cc

Please sign in to comment.