-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #14 from CodeWithKyrian/add-zero-shot-object-detection-pipeline
Add Zero Shot Object Detection Pipeline and OwlVit models
- Loading branch information
Showing
21 changed files
with
303 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Codewithkyrian\Transformers\Pipelines; | ||
|
||
use function Codewithkyrian\Transformers\Utils\memoryUsage; | ||
use function Codewithkyrian\Transformers\Utils\timeUsage; | ||
|
||
// Load the shared example bootstrap.
// NOTE(review): a path starting with "./" is resolved against the current
// working directory, not against this file, so the script must be launched
// from the directory that contains bootstrap.php — confirm the intended cwd.
require_once './bootstrap.php';

// Remove PHP's memory ceiling for this run.
ini_set('memory_limit', '-1');

$detector = pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');

// Alternative sample input. It was previously assigned and then immediately
// overwritten by the beach sample below, so it never took effect. Swap the
// two pairs to try it.
// $url = __DIR__ . '/../images/astronaut.png';
// $candidateLabels = ['human face', 'rocket', 'helmet', 'american flag'];

$url = __DIR__ . '/../images/beach.png';
$candidateLabels = ['hat', 'book', 'sunglasses', 'camera'];

// Keep at most the 4 best detections whose score passes the 0.05 threshold.
$output = $detector($url, $candidateLabels, topK: 4, threshold: 0.05);

// Dump the detections together with the timing and memory helpers' output.
dd($output, timeUsage(), memoryUsage());
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\FeatureExtractors; | ||
|
||
use Codewithkyrian\Transformers\Models\Output\ObjectDetectionOutput; | ||
use Codewithkyrian\Transformers\Processors\Processor; | ||
|
||
/**
 * Feature extractor for OwlViT models.
 *
 * Adds an object-detection post-processing entry point on top of
 * ImageFeatureExtractor; the actual work is delegated to the shared
 * Processor implementation.
 */
class OwlViTFeatureExtractor extends ImageFeatureExtractor
{
    /**
     * Turns raw object-detection model output into post-processed results.
     *
     * All four arguments are forwarded unchanged, in the same order, to
     * `Processor::postProcessObjectDetection()`.
     *
     * @param ObjectDetectionOutput $outputs     The model outputs to post-process.
     * @param float                 $threshold   Score threshold handed to the post-processor.
     * @param array|null            $targetSizes Sizes of the original images, or null.
     * @param bool                  $isZeroShot  Whether zero-shot object detection was performed.
     *
     * @return array The post-processed outputs as returned by the Processor.
     */
    public function postProcessObjectDetection(
        ObjectDetectionOutput $outputs,
        float $threshold = 0.5,
        ?array $targetSizes = null,
        bool $isZeroShot = false
    ): array {
        $postProcessed = Processor::postProcessObjectDetection(
            $outputs,
            $threshold,
            $targetSizes,
            $isZeroShot
        );

        return $postProcessed;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\FeatureExtractors; | ||
|
||
/**
 * Image processor for Owlv2 models.
 *
 * Declares no members of its own; all behavior is inherited from
 * OwlViTFeatureExtractor.
 */
class Owlv2ImageProcessor extends OwlViTFeatureExtractor
{
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\Models\Auto; | ||
|
||
/**
 * Auto-model entry for zero-shot object detection.
 *
 * Only supplies the lookup tables; how they are consumed is defined by
 * PretrainedMixin.
 */
class AutoModelForZeroShotObjectDetection extends PretrainedMixin
{
    /**
     * Model type identifier => model class implementing zero-shot object detection.
     */
    public const MODEL_CLASS_MAPPING = [
        'owlvit' => \Codewithkyrian\Transformers\Models\Pretrained\OwlViTForObjectDetection::class,
        'owlv2' => \Codewithkyrian\Transformers\Models\Pretrained\Owlv2ForObjectDetection::class,
    ];

    /**
     * List of mapping tables exposed by this auto-model (a single table here).
     */
    public const MODEL_CLASS_MAPPINGS = [
        self::MODEL_CLASS_MAPPING,
    ];
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\Models\Pretrained; | ||
|
||
use Codewithkyrian\Transformers\Models\Output\ObjectDetectionOutput; | ||
|
||
/**
 * OwlViT model whose invocation result is wrapped as an ObjectDetectionOutput.
 */
class OwlViTForObjectDetection extends OwlViTPretrainedModel
{
    /**
     * Runs the parent model and converts its result.
     *
     * @param array $modelInputs Inputs forwarded untouched to the parent `__invoke()`.
     *
     * @return ObjectDetectionOutput The parent's output passed through `ObjectDetectionOutput::fromOutput()`.
     */
    public function __invoke(array $modelInputs): ObjectDetectionOutput
    {
        $rawOutput = parent::__invoke($modelInputs);

        return ObjectDetectionOutput::fromOutput($rawOutput);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\Models\Pretrained; | ||
|
||
/**
 * Plain OwlViT model.
 *
 * Declares no members of its own; all behavior is inherited from
 * OwlViTPretrainedModel.
 */
class OwlViTModel extends OwlViTPretrainedModel
{
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\Models\Pretrained; | ||
|
||
/**
 * Common base class for the OwlViT model family.
 *
 * Declares no members of its own; all behavior is inherited from
 * PretrainedModel.
 */
class OwlViTPretrainedModel extends PretrainedModel
{
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\Models\Pretrained; | ||
|
||
use Codewithkyrian\Transformers\Models\Output\ObjectDetectionOutput; | ||
|
||
/**
 * Owlv2 model whose invocation result is wrapped as an ObjectDetectionOutput.
 */
class Owlv2ForObjectDetection extends Owlv2PretrainedModel
{
    /**
     * Runs the parent model and converts its result.
     *
     * @param array $modelInputs Inputs forwarded untouched to the parent `__invoke()`.
     *
     * @return ObjectDetectionOutput The parent's output passed through `ObjectDetectionOutput::fromOutput()`.
     */
    public function __invoke(array $modelInputs): ObjectDetectionOutput
    {
        $rawOutput = parent::__invoke($modelInputs);

        return ObjectDetectionOutput::fromOutput($rawOutput);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\Models\Pretrained; | ||
|
||
/**
 * Plain Owlv2 model.
 *
 * Declares no members of its own; all behavior is inherited from
 * Owlv2PretrainedModel.
 */
class Owlv2Model extends Owlv2PretrainedModel
{
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\Models\Pretrained; | ||
|
||
/**
 * Common base class for the Owlv2 model family.
 *
 * Declares no members of its own; all behavior is inherited from
 * PretrainedModel.
 */
class Owlv2PretrainedModel extends PretrainedModel
{
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\Pipelines; | ||
|
||
use Codewithkyrian\Transformers\Models\Output\ObjectDetectionOutput; | ||
use Codewithkyrian\Transformers\Utils\Tensor; | ||
use function Codewithkyrian\Transformers\Utils\getBoundingBox; | ||
use function Codewithkyrian\Transformers\Utils\prepareImages; | ||
|
||
/**
 * Zero-shot object detection pipeline. This pipeline predicts bounding boxes of
 * objects when you provide an image and a set of candidate labels.
 *
 * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32`.
 * ```php
 * $detector = pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
 * $url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
 * $candidateLabels = ['human face', 'rocket', 'helmet', 'american flag'];
 * $output = $detector($url, $candidateLabels);
 * // [
 * //   [
 * //     'score' => 0.24392342567443848,
 * //     'label' => 'human face',
 * //     'box' => ['xmin' => 180, 'ymin' => 67, 'xmax' => 274, 'ymax' => 175],
 * //   ],
 * //   ...
 * // ]
 * ```
 *
 * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32` (returning top 4 matches and setting a threshold).
 * ```php
 * $detector = pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
 * $url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
 * $candidateLabels = ['hat', 'book', 'sunglasses', 'camera'];
 * $output = $detector($url, $candidateLabels, topK: 4, threshold: 0.05);
 * // [
 * //   [
 * //     'score' => 0.1606510728597641,
 * //     'label' => 'sunglasses',
 * //     'box' => ['xmin' => 347, 'ymin' => 229, 'xmax' => 429, 'ymax' => 264],
 * //   ],
 * //   ...
 * // ]
 * ```
 */
class ZeroShotObjectDetectionPipeline extends Pipeline
{
    /**
     * Detects objects described by free-text labels in one or more images.
     *
     * @param array|string $inputs  A single image or an array of images, passed as-is to `prepareImages()`.
     * @param mixed        ...$args The candidate labels as the first positional argument (or as the
     *                              named argument `candidateLabels`), plus these optional named arguments:
     *                              - `threshold` (default 0.1): score threshold handed to post-processing.
     *                              - `topK` (default null): when not null, keep only this many of the
     *                                highest-scoring detections per image.
     *                              - `percentage` (default false): when true, no image size is passed to
     *                                post-processing. NOTE(review): presumably boxes then stay relative
     *                                to the image size — confirm against `getBoundingBox()`.
     *
     * @return array For a single (non-array) input, a list of detections sorted by descending score,
     *               each with the keys `score`, `label` and `box`. For an array input, one such list
     *               per image.
     *
     * @throws \InvalidArgumentException When no candidate labels are supplied.
     */
    public function __invoke(array|string $inputs, ...$args): array
    {
        // Fail loudly instead of letting an undefined-index warning and a null
        // value flow into the tokenizer when the labels are missing.
        $candidateLabels = $args[0]
            ?? $args['candidateLabels']
            ?? throw new \InvalidArgumentException(
                'Candidate labels are required for zero-shot object detection.'
            );
        $threshold = $args['threshold'] ?? 0.1;
        $topK = $args['topK'] ?? null;
        $percentage = $args['percentage'] ?? false;

        $isBatched = is_array($inputs);

        $preparedImages = prepareImages($inputs);

        // The labels are identical for every image, so tokenize them once.
        $textInputs = $this->tokenizer->tokenize($candidateLabels, padding: true, truncation: true);

        // Run the processor over all prepared images at once.
        $modelInputs = ($this->processor)($preparedImages);

        $toReturn = [];
        foreach ($preparedImages as $i => $image) {
            // Without an image size the post-processor has nothing to scale boxes to.
            $imageSize = $percentage ? null : [[$image->height(), $image->width()]];

            // Take this image's pixel values and add a leading dimension of size one.
            $pixelValues = Tensor::fromNdArray($modelInputs['pixel_values'][$i])->unsqueeze(0);

            // Run the model with both text and pixel inputs.
            /** @var ObjectDetectionOutput $output */
            $output = $this->model->__invoke(array_merge($textInputs, ['pixel_values' => $pixelValues]));

            // Only one image was fed in, so only the first post-processed entry is used.
            $processed = $this->processor->featureExtractor
                ->postProcessObjectDetection($output, $threshold, $imageSize, true)[0];

            $result = [];
            foreach ($processed['boxes'] as $j => $box) {
                $result[] = [
                    'score' => $processed['scores'][$j],
                    // The predicted class is an index into the candidate labels.
                    'label' => $candidateLabels[$processed['classes'][$j]],
                    'box' => getBoundingBox($box, !$percentage),
                ];
            }

            // Highest score first.
            usort($result, static fn (array $a, array $b): int => $b['score'] <=> $a['score']);

            if ($topK !== null) {
                $result = array_slice($result, 0, $topK);
            }

            $toReturn[] = $result;
        }

        return $isBatched ? $toReturn : $toReturn[0];
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
|
||
namespace Codewithkyrian\Transformers\Processors; | ||
|
||
/**
 * Processor for OwlViT models.
 *
 * Declares no members of its own; all behavior is inherited from Processor.
 */
class OwlViTProcessor extends Processor
{
}
Oops, something went wrong.