Skip to content

Commit

Permalink
Fix download command bug
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeWithKyrian committed Mar 14, 2024
1 parent 625223f commit ab9326d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ Explanation of Arguments:
models are
smaller and faster, but may have slightly lower accuracy. Use -q as a shortcut in the command.

> [!DANGER]
> Remember to add your cache directory to your `.gitignore` file to avoid committing the downloaded models to your git
> repository.
## Supported tasks/models

This package is a WIP, but here's a list of tasks and architectures currently tested and supported by Transformers PHP.
Expand Down
14 changes: 8 additions & 6 deletions src/Commands/DownloadModelCommand.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
use Codewithkyrian\Transformers\Models\Auto\AutoModelForCausalLM;
use Codewithkyrian\Transformers\Models\Auto\AutoModelForSeq2SeqLM;
use Codewithkyrian\Transformers\Models\Auto\AutoModelForSequenceClassification;
use Codewithkyrian\Transformers\Transformers;
use Symfony\Component\Console\Attribute\AsCommand;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
Expand All @@ -35,8 +36,7 @@ protected function configure(): void
'cache-dir',
'c',
InputOption::VALUE_OPTIONAL,
'The directory to cache the model in.',
'models'
'The directory to cache the model in.'
);

$this->addOption(
Expand All @@ -51,6 +51,8 @@ protected function configure(): void

protected function execute(InputInterface $input, OutputInterface $output): int
{
Transformers::configure();

$output->writeln('✔ Downloading model...');

$model = $input->getArgument('model');
Expand All @@ -63,10 +65,10 @@ protected function execute(InputInterface $input, OutputInterface $output): int
try {
// TODO: Verify the tasks and corresponding AutoModel classes
$model = match ($task) {
'text-generation' => AutoModelForCausalLM::fromPretrained($model, $quantized, $cacheDir),
'text-classification', 'sentiment-analysis' => AutoModelForSequenceClassification::fromPretrained($model, $quantized, $cacheDir),
'translation' => AutoModelForSeq2SeqLM::fromPretrained($model, $quantized, $cacheDir),
default => AutoModel::fromPretrained($model, $quantized, $cacheDir),
'text-generation' => AutoModelForCausalLM::fromPretrained($model, $quantized, cacheDir: $cacheDir),
'text-classification', 'sentiment-analysis' => AutoModelForSequenceClassification::fromPretrained($model, $quantized, cacheDir: $cacheDir),
'translation' => AutoModelForSeq2SeqLM::fromPretrained($model, $quantized, cacheDir: $cacheDir),
default => AutoModel::fromPretrained($model, $quantized, cacheDir: $cacheDir),
};

$output->writeln('✔ Model downloaded successfully.');
Expand Down
1 change: 1 addition & 0 deletions src/Utils/Hub.php
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public static function getFile(
{
# Local cache and file paths
$cacheDir ??= Transformers::$cacheDir;

$filePath = joinPaths($cacheDir, $pathOrRepoID, $subFolder, $fileName);

# Check if file already exists
Expand Down

0 comments on commit ab9326d

Please sign in to comment.