Skip to content

Commit

Permalink
Add softmax, max and argmax to Tensor class
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeWithKyrian committed Apr 10, 2024
1 parent c184b25 commit 215f857
Show file tree
Hide file tree
Showing 4 changed files with 472 additions and 443 deletions.
12 changes: 6 additions & 6 deletions src/Pipelines/QuestionAnsweringPipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ public function __invoke(array|string $inputs, ...$args): array

$inputs = $this->tokenizer->__invoke($question, $context, padding: true, truncation: true);

/** @var QuestionAnsweringModelOutput $outputs */
$outputs = $this->model->__invoke($inputs);
/** @var QuestionAnsweringModelOutput $outputs */
$outputs = $this->model->__invoke($inputs);

$toReturn = [];

for ($i = 0; $i < $outputs->startLogits->shape()[0]; ++$i) {
$ids = $inputs['input_ids']->toArray()[$i];
$sepIndex = array_search($this->tokenizer->sepTokenId, $ids);

$startLogits = $outputs->startLogits[$i]->buffer()->toArray();
$endLogits = $outputs->endLogits[$i]->buffer()->toArray();
$startLogits = $outputs->startLogits[$i];
$endLogits = $outputs->endLogits[$i];

// Compute softmax for start and end logits and filter based on separator index
$s1 = array_filter(
array_map(
fn($x) => [$x[0], $x[1]],
array_map(null, Math::softmax($startLogits), range(0, count($startLogits) - 1))
array_map(null, $startLogits->softmax(), range(0, count($startLogits) - 1))
),
fn($x) => $x[1] > $sepIndex
);
Expand All @@ -56,7 +56,7 @@ public function __invoke(array|string $inputs, ...$args): array
$e1 = array_filter(
array_map(
fn($x) => [$x[0], $x[1]],
array_map(null, Math::softmax($endLogits), range(0, count($endLogits) - 1))
array_map(null, $endLogits->softmax(), range(0, count($endLogits) - 1))
),
fn($x) => $x[1] > $sepIndex
);
Expand Down
4 changes: 2 additions & 2 deletions src/Pipelines/TextClassificationPipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ public function __invoke(array|string $inputs, ...$args): array

$activationFunction = $problemType == 'multi_label_classification' ?
fn(Tensor $batch) => $batch->sigmoid()->toArray() :
fn(Tensor $batch) => Math::softmax($batch->toArray());
fn(Tensor $batch) => $batch->softmax();

$id2label = $this->model->config['id2label'];
$toReturn = [];

foreach ($outputs->logits as $batch) {
$output = $activationFunction(Tensor::fromArray($batch));
$output = $activationFunction($batch);
$scores = Math::getTopItems($output, $topK);

$values = array_map(function ($score) use ($id2label) {
Expand Down
32 changes: 15 additions & 17 deletions src/Pipelines/TokenClassificationPipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

use Codewithkyrian\Transformers\Models\Output\TokenClassifierOutput;
use Codewithkyrian\Transformers\Utils\AggregationStrategy;
use Codewithkyrian\Transformers\Utils\Math;
use Codewithkyrian\Transformers\Utils\Tensor;
use Exception;

/**
* Named Entity Recognition pipeline using any `ModelForTokenClassification`.
Expand Down Expand Up @@ -73,8 +72,7 @@ public function __invoke(array|string $inputs, ...$args): array
$entities = [];

for ($j = 0; $j < $batch->shape()[0]; ++$j) {
$tokenData = $batch[$j]->toArray();
$topScoreIndex = array_search(max($tokenData), $tokenData);
$topScoreIndex = $batch[$j]->argMax();

$entity = $id2label[$topScoreIndex] ?? "LABEL_{$topScoreIndex}";

Expand All @@ -86,7 +84,7 @@ public function __invoke(array|string $inputs, ...$args): array
continue;
}

$scores = Math::softmax($tokenData);
$scores = $batch[$j]->softmax();

$entities[] = [
'entity' => $entity,
Expand Down Expand Up @@ -168,7 +166,7 @@ protected function aggregateWord(array $entities, AggregationStrategy $aggregati
break;

default:
throw new \Exception("Invalid aggregation_strategy");
throw new Exception("Invalid aggregation_strategy");
}


Expand Down Expand Up @@ -221,6 +219,17 @@ protected function groupEntities($entities): array
return $entityGroups;
}

protected function getTag($entityName): array
{
if (str_starts_with($entityName, "B-")) {
return ["B", substr($entityName, 2)];
} elseif (str_starts_with($entityName, "I-")) {
return ["I", substr($entityName, 2)];
} else {
return ["I", $entityName]; // Default to "I" for continuation
}
}

/**
* Group together the adjacent tokens with the same entity predicted.
*
Expand All @@ -244,15 +253,4 @@ public function groupSubEntities(array $entities): array
];
}

protected function getTag($entityName): array
{
if (str_starts_with($entityName, "B-")) {
return ["B", substr($entityName, 2)];
} elseif (str_starts_with($entityName, "I-")) {
return ["I", substr($entityName, 2)];
} else {
return ["I", $entityName]; // Default to "I" for continuation
}
}

}
Loading

0 comments on commit 215f857

Please sign in to comment.