Skip to content

Commit

Permalink
Add dot and cross product to Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeWithKyrian committed Apr 11, 2024
1 parent 06940e4 commit d014ef0
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/Pipelines/ImageToImagePipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public function __invoke(array|string $inputs, ...$args): array|Image
foreach ($outputs['reconstruction'] as $i => $batch) {
$output = $batch->squeeze()
->clamp(0, 1)
->multiplyScalar(255)
->multiply(255)
->round()
->to(NDArray::uint8);

Expand Down
48 changes: 29 additions & 19 deletions src/Utils/Tensor.php
Original file line number Diff line number Diff line change
Expand Up @@ -477,34 +477,21 @@ protected function calcUnsqueezeDims(array $dims, int $dim): array
}

/**
* Add two NDArrays element-wise, A + B
* Add a tensor or scalar to this tensor. If it's a tensor, it must be the same shape, and it performs
* an element-wise addition. If it's a scalar, it adds the scalar to every element in the tensor.
*
* @param Tensor $other The NDArray to add to this NDArray.
* @param Tensor|float|int $other The NDArray to add to this NDArray.
* @return static
*/
public function add(Tensor $other): static
public function add(Tensor|float|int $other): static
{
$mo = self::getMo();

$ndArray = $mo->add($this, $other);
$ndArray = is_scalar($other) ? $mo->op($this, '+', $other) : $mo->add($this, $other);

return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
}

/**
* Return a new Tensor with every element added by a constant.
*
* @param float|int $scalar The constant to add.
* @return static
*/
public function addScalar(float|int $scalar): static
{
$mo = self::getMo();

$ndArray = $mo->op($this, '+', $scalar);

return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
}

/**
* Return a new Tensor with the sigmoid function applied to each element.
Expand All @@ -526,7 +513,7 @@ public function sigmoid(): self
*
* @return self
*/
public function multiplyScalar(float|int $scalar): self
public function multiply(float|int $scalar): self
{
$mo = self::getMo();

Expand All @@ -535,6 +522,29 @@ public function multiplyScalar(float|int $scalar): self
return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
}

/**
* Calculate the dot product of this tensor and another tensor.
*/
public function dot(Tensor $other): float
{
$mo = self::getMo();

return $mo->dot($this, $other);
}

/**
* Calculate the cross product of this tensor and another tensor. The shapes of the tensors must be compatible for
* cross product
*/
public function cross(Tensor $other): Tensor
{
$mo = self::getMo();

$crossProduct = $mo->cross($this, $other);

return new static($crossProduct->buffer(), $crossProduct->dtype(), $crossProduct->shape(), $crossProduct->offset());
}

/**
* Return a transposed version of this Tensor.
* @return $this
Expand Down
6 changes: 3 additions & 3 deletions tests/Utils/TensorTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

it('can add a scalar to each element of a tensor', function () {
$tensor = Tensor::fromArray([[1, 2], [3, 4]]);
$result = $tensor->addScalar(5);
$result = $tensor->add(5);

expect($result)->toBeInstanceOf(Tensor::class)
->and($result->toArray())->toBe([[6, 7], [8, 9]]);
Expand All @@ -82,15 +82,15 @@

it('can multiply each element of a tensor by a scalar', function () {
$tensor = Tensor::fromArray([[1, 2], [3, 4]]);
$result = $tensor->multiplyScalar(2);
$result = $tensor->multiply(2);

expect($result)->toBeInstanceOf(Tensor::class)
->and($result->toArray())->toBe([[2.0, 4.0], [6.0, 8.0]]);
});

it('can compute the mean value of each row of the tensor', function () {
$tensor = Tensor::fromArray([[1, 2], [3, 4]]);
$result = $tensor->mean(dim: 1);
$result = $tensor->mean(axis: 1);

expect($result)->toBeInstanceOf(Tensor::class)
->and($result->toArray())->toBe([1.5, 3.5]);
Expand Down

0 comments on commit d014ef0

Please sign in to comment.