Skip to content

Commit

Permalink
Add support for tensor softmax for 2D tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeWithKyrian committed Apr 11, 2024
1 parent 9a0fc70 commit 794dea5
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 12 deletions.
5 changes: 2 additions & 3 deletions src/Pipelines/ZeroShotImageClassificationPipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

namespace Codewithkyrian\Transformers\Pipelines;

use Codewithkyrian\Transformers\Utils\Math;
use Codewithkyrian\Transformers\Utils\Tensor;
use function Codewithkyrian\Transformers\Utils\prepareImages;

Expand Down Expand Up @@ -53,14 +52,14 @@ public function __invoke(array|string $inputs, ...$args): array

$activationFn = $this->model->config['model_type'] === 'siglip' ?
fn(Tensor $batch) => $batch->sigmoid()->toArray() :
fn(Tensor $batch) => Math::softmax($batch->toArray());
fn(Tensor $batch) => $batch->softmax();

// Compare each image with each candidate label
$toReturn = [];

foreach ($output['logits_per_image'] as $batch) {
// Compute softmax per image
$scores = $activationFn(Tensor::fromArray($batch));
$scores = $activationFn($batch);

$result = [];
foreach ($scores as $i => $score) {
Expand Down
7 changes: 2 additions & 5 deletions src/Processors/Processor.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
use Codewithkyrian\Transformers\FeatureExtractors\FeatureExtractor;
use Codewithkyrian\Transformers\Models\Output\ObjectDetectionOutput;
use Codewithkyrian\Transformers\Utils\Math;
use Codewithkyrian\Transformers\Utils\Tensor;
use Exception;

/**
Expand Down Expand Up @@ -70,10 +69,8 @@ public static function postProcessObjectDetection(ObjectDetectionOutput $outputs
}
$probs = $logitSigmoid;
} else {
$mo = Tensor::getMo();

// Get most probable class
$maxIndex = $mo->argMax($logit);
$maxIndex = $logit->argMax();

if ($maxIndex === $numClasses - 1) {
// This is the background class, skip it
Expand All @@ -82,7 +79,7 @@ public static function postProcessObjectDetection(ObjectDetectionOutput $outputs
$indices[] = $maxIndex;

// Compute softmax over classes
$probs = Math::softmax($logit->toArray());
$probs = $logit->softmax();
}

foreach ($indices as $index) {
Expand Down
2 changes: 0 additions & 2 deletions src/Utils/Helpers.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

namespace Codewithkyrian\Transformers\Utils;

use Exception;

function memoryUsage(): string
{
$mem = memory_get_usage(true);
Expand Down
17 changes: 15 additions & 2 deletions src/Utils/Tensor.php
Original file line number Diff line number Diff line change
Expand Up @@ -929,9 +929,22 @@ protected function unflattenArray($flatArray, &$currentIndex, array $shape): arr
* Calculate the softmax of the tensor.
*
*/
public function softmax(): array
public function softmax(): array|static
{
return Math::softmax($this->toArray());
return match ($this->ndim()) {
1 => Math::softmax($this->toArray()),
2 => $this->softmax2D(),
default => throw new InvalidArgumentException("Softmax is only supported for 1D and 2D tensors.")
};
}

protected function softmax2D(): static
{
$mo = self::getMo();

$ndArray = $mo->la()->softmax($this);

return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
}

public function max(?int $axis = null): static|int|float
Expand Down

0 comments on commit 794dea5

Please sign in to comment.