Skip to content

Commit b333162

Browse files
Bump OnnxRuntime PHP to 0.2.0 and improve download and install command interfaces
1 parent 9454f25 commit b333162

16 files changed

+161
-78
lines changed

composer.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"require": {
1616
"php": "^8.1",
1717
"guzzlehttp/guzzle": "^7.0",
18-
"ankane/onnxruntime": "^0.1.8",
18+
"ankane/onnxruntime": "^0.2.0",
1919
"ext-gd": "*",
2020
"rindow/rindow-math-matrix": "^1.2",
2121
"spatie/fork": "^1.2"

examples/pipelines/feature-extraction.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414

1515
$embeddings = $extractor('The quick brown fox jumps over the lazy dog.', normalize: true, pooling: 'mean');
1616

17-
dd(memoryUsage(), timeUsage(), $embeddings[0]);
17+
dd(memoryUsage(), timeUsage(milliseconds: true), count($embeddings[0]));

examples/pipelines/text-classification.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
$classifier = pipeline('text-classification', 'Xenova/distilbert-base-uncased-mnli');
1616

17-
$result = $classifier('A person is eating, there is no food left');
17+
$result = $classifier('I love you!, You frustrated my life');
1818

1919

2020
dd($result);

src/Commands/DownloadModelCommand.php

+7-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
use Codewithkyrian\Transformers\Transformers;
1515
use Symfony\Component\Console\Attribute\AsCommand;
1616
use Symfony\Component\Console\Command\Command;
17+
use Symfony\Component\Console\Helper\ProgressBar;
1718
use Symfony\Component\Console\Input\InputArgument;
1819
use Symfony\Component\Console\Input\InputInterface;
1920
use Symfony\Component\Console\Input\InputOption;
@@ -54,7 +55,7 @@ protected function configure(): void
5455

5556
protected function execute(InputInterface $input, OutputInterface $output): int
5657
{
57-
$output->writeln('Downloading model...');
58+
$output->writeln('Initializing download...');
5859

5960
$model = $input->getArgument('model');
6061
$cacheDir = $input->getOption('cache-dir');
@@ -69,18 +70,18 @@ protected function execute(InputInterface $input, OutputInterface $output): int
6970
$task = $task ? Task::tryFrom($task) : null;
7071

7172
if ($task != null) {
72-
pipeline($task, $model);
73+
pipeline($task, $model, output: $output);
7374
} else {
74-
AutoTokenizer::fromPretrained($model, $quantized);
75-
AutoModel::fromPretrained($model, $quantized);
75+
AutoTokenizer::fromPretrained($model, quantized: $quantized, output: $output);
76+
AutoModel::fromPretrained($model, $quantized, output: $output);
7677
}
7778

7879

79-
$output->writeln('✔ Model downloaded successfully.');
80+
$output->writeln('✔ Model files downloaded successfully.');
8081

8182
return Command::SUCCESS;
8283
} catch (\Exception $e) {
83-
$output->writeln('An error occurred while downloading the model: ' . $e->getMessage());
84+
$output->writeln(''. $e->getMessage());
8485
return Command::FAILURE;
8586
}
8687
}

src/Commands/InitCommand.php

+57-10
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55
namespace Codewithkyrian\Transformers\Commands;
66

77
use Codewithkyrian\Transformers\Transformers;
8+
use GuzzleHttp\Client;
9+
use GuzzleHttp\Exception\GuzzleException;
10+
use GuzzleHttp\Psr7\Utils;
811
use OnnxRuntime\Exception;
912
use OnnxRuntime\Vendor;
13+
use PharData;
1014
use Symfony\Component\Console\Attribute\AsCommand;
1115
use Symfony\Component\Console\Command\Command;
16+
use Symfony\Component\Console\Helper\ProgressBar;
1217
use Symfony\Component\Console\Input\InputInterface;
1318
use Symfony\Component\Console\Input\InputOption;
1419
use Symfony\Component\Console\Output\OutputInterface;
@@ -52,48 +57,90 @@ protected function execute(InputInterface $input, OutputInterface $output): int
5257

5358
ensureDirectory(Transformers::$cacheDir);
5459

55-
echo "◌ Downloading ONNX Runtime...\n";
60+
echo "✔ Initializing Transformers...\n";
5661

5762
$file = Transformers::platform('file');
5863
$ext = Transformers::platform('ext');
5964

6065
$urlTemplate = "https://github.com/microsoft/onnxruntime/releases/download/v{{version}}/$file.$ext";
61-
$url = str_replace('{{version}}', Transformers::ONNX_VERSION, $urlTemplate);
66+
$url = str_replace('{{version}}', Vendor::VERSION, $urlTemplate);
6267

63-
$contents = @file_get_contents($url);
68+
$client = new Client();
69+
$tempDest = tempnam(sys_get_temp_dir(), 'onnxruntime') . '.' . $ext;
6470

65-
if (!$contents) {
66-
throw new \Exception("Something went wrong");
67-
}
71+
ProgressBar::setFormatDefinition('hub', '%filename% : [%bar%] %percent:3s%%');
72+
73+
$progressBar = new ProgressBar($output, 100);
74+
$progressBar->setFormat('hub');
75+
$progressBar->setBarCharacter('<fg=green>•</>');
76+
$progressBar->setEmptyBarCharacter("<fg=red>⚬</>");
77+
$progressBar->setProgressCharacter('<fg=green>➤</>');
78+
$progressBar->setMessage("✔ Downloading Libraries", 'filename');
79+
80+
$client->get($url, ['sink' => $tempDest, 'progress' => self::onProgress($progressBar)]);
81+
82+
$contents = @file_get_contents($tempDest);
6883

6984
$checksum = hash('sha256', $contents);
85+
7086
if ($checksum != Transformers::platform('checksum')) {
7187
throw new Exception("Bad checksum: $checksum");
7288
}
7389

74-
$tempDest = tempnam(sys_get_temp_dir(), 'onnxruntime') . '.' . $ext;
75-
76-
file_put_contents($tempDest, $contents);
90+
$archive = new PharData($tempDest);
7791

78-
$archive = new \PharData($tempDest);
7992
if ($ext != 'zip') {
8093
$archive = $archive->decompress();
8194
}
8295

8396
$archive->extractTo(Transformers::$cacheDir);
8497

98+
echo "\n"; // New line to since Symphony ProgressBar doesn't add a new line.
8599
$output->writeln('✔ Initialized Transformers successfully.');
86100

87101
$this->askToStar($input, $output);
88102

89103
return Command::SUCCESS;
104+
} catch (GuzzleException $e) {
105+
$output->writeln($e->getMessage());
106+
107+
return Command::FAILURE;
90108
} catch (Exception $e) {
91109
$output->writeln($e->getMessage());
92110

93111
return Command::FAILURE;
94112
}
95113
}
96114

115+
/**
116+
* @param resource $stream
117+
* @return string
118+
*/
119+
public function calculateHash($stream): string
120+
{
121+
$ctx = hash_init('sha256');
122+
123+
while (!feof($stream)) {
124+
$buffer = fread($stream, 8192); // Read in 8KB chunks
125+
hash_update($ctx, $buffer);
126+
}
127+
128+
$hash = hash_final($ctx);
129+
fclose($stream);
130+
131+
return $hash;
132+
}
133+
134+
private static function onProgress(ProgressBar $progressBar): callable
135+
{
136+
return function ($totalDownload, $downloadedBytes) use ($progressBar) {
137+
if ($totalDownload == 0) return;
138+
139+
$percent = round(($downloadedBytes / $totalDownload) * 100, 2);
140+
$progressBar->setProgress((int)$percent);
141+
};
142+
}
143+
97144

98145
protected function askToStar(InputInterface $input, OutputInterface $output): void
99146
{

src/Models/Auto/PretrainedMixin.php

+7-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use Codewithkyrian\Transformers\Models\ModelArchitecture;
1010
use Codewithkyrian\Transformers\Models\Pretrained\PreTrainedModel;
1111
use Codewithkyrian\Transformers\Utils\AutoConfig;
12+
use Symfony\Component\Console\Output\OutputInterface;
1213

1314
/**
1415
* Base class of all AutoModels. Contains the `from_pretrained` function
@@ -46,9 +47,10 @@ public static function fromPretrained(
4647
?string $cacheDir = null,
4748
string $revision = 'main',
4849
?string $modelFilename = null,
50+
?OutputInterface $output = null
4951
): PreTrainedModel
5052
{
51-
$config = AutoConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision);
53+
$config = AutoConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $output);
5254

5355
foreach (static::MODEL_CLASS_MAPPINGS as $modelClassMapping) {
5456
$modelClass = $modelClassMapping[$config->modelType] ?? null;
@@ -65,7 +67,8 @@ public static function fromPretrained(
6567
cacheDir: $cacheDir,
6668
revision: $revision,
6769
modelFilename: $modelFilename,
68-
modelArchitecture: $modelArchitecture
70+
modelArchitecture: $modelArchitecture,
71+
output: $output
6972
);
7073
}
7174

@@ -78,7 +81,8 @@ public static function fromPretrained(
7881
config: $config,
7982
cacheDir: $cacheDir,
8083
revision: $revision,
81-
modelFilename: $modelFilename
84+
modelFilename: $modelFilename,
85+
output: $output
8286
);
8387
} else {
8488
throw UnsupportedModelTypeException::make($config->modelType);

src/Models/ModelArchitecture.php

+7-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
namespace Codewithkyrian\Transformers\Models;
66

7+
use Codewithkyrian\Transformers\Exceptions\MissingModelInputException;
8+
use Codewithkyrian\Transformers\Exceptions\ModelExecutionException;
79
use Codewithkyrian\Transformers\Models\Pretrained\PreTrainedModel;
810
use Codewithkyrian\Transformers\Utils\GenerationConfig;
911
use Codewithkyrian\Transformers\Utils\Tensor;
@@ -79,11 +81,11 @@ protected function encoderForward(PreTrainedModel $model, array $modelInputs): a
7981
{
8082
$encoderFeeds = [];
8183

82-
foreach ($model->session->inputs as ['name' => $inputName]) {
84+
foreach ($model->session->inputs() as ['name' => $inputName]) {
8385
$encoderFeeds[$inputName] = $modelInputs[$inputName];
8486
}
8587

86-
$hasTokenTypeIds = in_array('token_type_ids', array_column($model->session->inputs, 'name'));
88+
$hasTokenTypeIds = in_array('token_type_ids', array_column($model->session->inputs(), 'name'));
8789

8890
if ($hasTokenTypeIds) {
8991
// Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it,
@@ -196,6 +198,7 @@ protected function decoderUpdatebeam(array &$beam, int $newTokenId): void
196198
* @param PreTrainedModel $model The model to use for the forward pass.
197199
* @param array $modelInputs The inputs to the model.
198200
* @return array The output of the forward pass.
201+
* @throws MissingModelInputException|ModelExecutionException
199202
*/
200203
protected function decoderForward(PreTrainedModel $model, array $modelInputs): array
201204
{
@@ -209,7 +212,7 @@ protected function decoderForward(PreTrainedModel $model, array $modelInputs): a
209212

210213
$useCacheBranch = !!$pastKeyValues;
211214

212-
$inputNames = array_column($model->session->inputs, 'name');
215+
$inputNames = array_column($model->session->inputs(), 'name');
213216

214217
if (in_array('use_cache_branch', $inputNames)) {
215218
$decoderFeeds['use_cache_branch'] = new Tensor([$useCacheBranch], shape: [1]);
@@ -345,7 +348,7 @@ protected function seq2seqForward(PreTrainedModel $model, array $modelInputs): a
345348

346349
$useCacheBranch = !!$pastKeyValues;
347350

348-
$inputNames = array_column($model->decoderMergedSession->inputs, 'name');
351+
$inputNames = array_column($model->decoderMergedSession->inputs(), 'name');
349352

350353

351354
if (in_array('use_cache_branch', $inputNames)) {

src/Models/Pretrained/PreTrainedModel.php

+10-7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
use Codewithkyrian\Transformers\Utils\Tensor;
3232
use Exception;
3333
use OnnxRuntime\InferenceSession;
34+
use Symfony\Component\Console\Output\OutputInterface;
3435
use function Codewithkyrian\Transformers\Utils\array_some;
3536

3637
/**
@@ -83,22 +84,23 @@ public static function fromPretrained(
8384
?string $token = null,
8485
string $revision = 'main',
8586
?string $modelFilename = null,
86-
ModelArchitecture $modelArchitecture = ModelArchitecture::EncoderOnly
87+
ModelArchitecture $modelArchitecture = ModelArchitecture::EncoderOnly,
88+
?OutputInterface $output = null
8789
): self
8890
{
8991
if (is_array($config)) {
90-
$config = AutoConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision);
92+
$config = AutoConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $output);
9193
}
9294

9395

9496
switch ($modelArchitecture) {
9597
case ModelArchitecture::DecoderOnly:
9698
{
9799
$session = self::constructSession(modelNameOrPath: $modelNameOrPath,
98-
fileName: $modelFilename ?? 'decoder_model_merged', cacheDir: $cacheDir, revision: $revision);
100+
fileName: $modelFilename ?? 'decoder_model_merged', cacheDir: $cacheDir, revision: $revision, output: $output);
99101

100102
$generatorConfigArr = Hub::getJson(pathOrRepoID: $modelNameOrPath, fileName: 'generation_config.json',
101-
cacheDir: $cacheDir, revision: $revision, fatal: false);
103+
cacheDir: $cacheDir, revision: $revision, fatal: false, output: $output);
102104

103105
$generatorConfig = new GenerationConfig($generatorConfigArr);
104106

@@ -248,12 +250,13 @@ public static function constructSession(
248250
string $revision = 'main',
249251
string $subFolder = '',
250252
bool $fatal = true,
253+
?OutputInterface $output = null,
251254
...$sessionOptions
252255
): ?InferenceSession
253256
{
254257
$modelFileName = sprintf('onnx/%s%s.onnx', $fileName, $quantized ? '_quantized' : '');
255258

256-
$file = Hub::getFile($modelNameOrPath, $modelFileName, $cacheDir, $revision, $subFolder, $fatal);
259+
$file = Hub::getFile($modelNameOrPath, $modelFileName, $cacheDir, $revision, $subFolder, $fatal, null, $output);
257260

258261
if ($file === null) return null;
259262

@@ -310,11 +313,11 @@ public function validateInputs(array $inputNames, array $inputs): array
310313
public function runSession(InferenceSession $session, array $inputs): array
311314
{
312315
try {
313-
$inputNames = array_column($session->inputs, 'name');
316+
$inputNames = array_column($session->inputs(), 'name');
314317

315318
$inputs = $this->validateInputs($inputNames, $inputs);
316319

317-
$outputNames = array_column($session->outputs, 'name');
320+
$outputNames = array_column($session->outputs(), 'name');
318321

319322
$outputs = $session->run($outputNames, $inputs);
320323

src/Pipelines/Pipeline.php

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
use Codewithkyrian\Transformers\Models\Pretrained\PreTrainedModel;
99
use Codewithkyrian\Transformers\PretrainedTokenizers\AutoTokenizer;
1010
use Codewithkyrian\Transformers\PretrainedTokenizers\PretrainedTokenizer;
11+
use Symfony\Component\Console\Output\OutputInterface;
1112

1213
class Pipeline
1314
{
@@ -60,6 +61,7 @@ function pipeline(
6061
?array $config = null,
6162
?string $cacheDir = null,
6263
string $revision = 'main',
64+
?OutputInterface $output = null
6365
): Pipeline
6466
{
6567
if (is_string($task)) {
@@ -73,9 +75,9 @@ function pipeline(
7375

7476
$modelName ??= $task->defaultModelName();
7577

76-
$model = $task->pretrainedModel($modelName, $quantized, $config, $cacheDir, $revision);
78+
$model = $task->pretrainedModel($modelName, $quantized, $config, $cacheDir, $revision, $output);
7779

78-
$tokenizer = AutoTokenizer::fromPretrained($modelName, $quantized, $config, $cacheDir, $revision);
80+
$tokenizer = AutoTokenizer::fromPretrained($modelName, $quantized, $config, $cacheDir, $revision, $output);
7981

8082
return $task->getPipeline($model, $tokenizer);
8183
}

0 commit comments

Comments
 (0)