Skip to content

Commit

Permalink
bugfix: Tensor mean method not working for scalar mean
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeWithKyrian committed Apr 11, 2024
1 parent 794dea5 commit 06940e4
Showing 1 changed file with 37 additions and 5 deletions.
42 changes: 37 additions & 5 deletions src/Utils/Tensor.php
Original file line number Diff line number Diff line change
Expand Up @@ -737,17 +737,23 @@ public function to(int $dtype): static
/**
* Returns the mean value of each row of the tensor in the given dimension dim.
*/
public function mean(?int $dim = null, bool $keepdims = false): static
public function mean(?int $axis = null, bool $keepdims = false): static|float|int
{
$mo = self::getMo();

$ndArray = $mo->mean($this, $dim);
$mean = $mo->mean($this, $axis);

if (!$keepdims) {
array_splice($ndArray->shape, $dim, 1);
if ($mean instanceof NDArray) {
$shape = $mean->shape();

if (!$keepdims) {
array_splice($shape, $axis, 1);
}

return new static($mean->buffer(), $mean->dtype(), $shape, $mean->offset());
}

return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
return $mean;
}

/**
Expand Down Expand Up @@ -973,6 +979,32 @@ public function argMax(?int $axis = null): static|int|float
return $argMax;
}

public function min(?int $axis = null): static|int|float
{
$mo = self::getMo();

$min = $mo->min($this, $axis);

if ($min instanceof NDArray) {
return new static($min->buffer(), $min->dtype(), $min->shape(), $min->offset());
}

return $min;
}

public function argMin(?int $axis = null): static|int|float
{
$mo = self::getMo();

$argMin = $mo->argMin($this, $axis);

if ($argMin instanceof NDArray) {
return new static($argMin->buffer(), $argMin->dtype(), $argMin->shape(), $argMin->offset());
}

return $argMin;
}


public function offsetSet($offset, $value): void
{
Expand Down

0 comments on commit 06940e4

Please sign in to comment.