Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tensordot and repeat #95

Merged
merged 12 commits into from
Jan 17, 2025
4 changes: 4 additions & 0 deletions ndonnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@
trunc,
matmul,
matrix_transpose,
tensordot,
concat,
expand_dims,
flip,
permute_dims,
repeat,
reshape,
roll,
squeeze,
Expand Down Expand Up @@ -263,10 +265,12 @@
"trunc",
"matmul",
"matrix_transpose",
"tensordot",
"concat",
"expand_dims",
"flip",
"permute_dims",
"repeat",
"reshape",
"roll",
"squeeze",
Expand Down
3 changes: 3 additions & 0 deletions ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ def matmul(self, x, y) -> ndx.Array:
def matrix_transpose(self, x) -> ndx.Array:
    """Fallback ``matrix_transpose`` hook; always yields ``NotImplemented``."""
    return NotImplemented

neNasko1 marked this conversation as resolved.
Show resolved Hide resolved
def tensordot(self, x, y, axes=2) -> ndx.Array:
    """Fallback ``tensordot`` hook for dtypes that do not implement it.

    ``axes`` is accepted because the public ``tensordot`` dispatcher invokes
    this hook as ``_ops.tensordot(x, y, axes)``. Without the parameter the
    fallback raised ``TypeError`` instead of signalling "unsupported".

    Args:
        x: Left operand (unused).
        y: Right operand (unused).
        axes: Contraction axes (unused); defaults to 2 like the dispatcher.

    Returns:
        ``NotImplemented``, always.
    """
    return NotImplemented

# searching.py

def argmax(self, x, axis=None, keepdims=False) -> ndx.Array:
Expand Down
4 changes: 4 additions & 0 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,10 @@ def matmul(self, x, y):
def matrix_transpose(self, x) -> ndx.Array:
    """Swap the last two axes of ``x`` while leaving the leading axes in place."""
    rank = x.ndim
    # Identity permutation for every leading axis, then the final pair reversed.
    permutation = [*range(rank - 2), rank - 1, rank - 2]
    return ndx.permute_dims(x, permutation)

@validate_core
def tensordot(self, x, y, axes):
    """Contract ``x`` with ``y`` over ``axes`` through ``opx.tensordot``.

    The operands are routed through ``_via_i64_f64`` rather than being handed
    to ``opx.tensordot`` directly.
    """

    def _contract(lhs, rhs):
        return opx.tensordot(lhs, rhs, axes)

    return _via_i64_f64(_contract, [x, y])

# searching.py

@validate_core
Expand Down
19 changes: 19 additions & 0 deletions ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,25 @@ def expand_dims(self, x, axis):
)
)

def repeat(self, x, repeats, axis):
    """Repeat the entries of ``x`` along ``axis`` according to ``repeats``.

    With ``axis=None`` the input is first flattened and repeated along axis 0.
    A Python ``int`` for ``repeats`` is wrapped into an array; either way
    ``repeats`` is broadcast to the length of the chosen axis.
    """
    if axis is None:
        # No axis given: operate on the flattened input.
        x, axis = ndx.reshape(x, [-1]), 0

    dims = ndx.additional.shape(x)

    # TODO: this case can be optimized by broadcasting and reshaping
    per_entry = ndx.asarray(repeats) if isinstance(repeats, int) else repeats

    axis_len = ndx.reshape(dims[axis], [1])
    per_entry = ndx.broadcast_to(per_entry, axis_len)

    # Output position ``i`` is taken from the first source index whose running
    # repeat total exceeds ``i`` (hence ``side="right"``).
    boundaries = ndx.cumulative_sum(per_entry).astype(ndx.uint64)
    positions = ndx.arange(ndx.sum(per_entry))
    source_idx = ndx.searchsorted(boundaries, positions, side="right")
    return ndx.take(x, source_idx, axis=axis)
neNasko1 marked this conversation as resolved.
Show resolved Hide resolved

def flip(self, x, axis):
if x.ndim == 0:
return x.copy()
Expand Down
16 changes: 16 additions & 0 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,16 @@ def matrix_transpose(x):
return _unary(x.dtype._ops.matrix_transpose, x)


def tensordot(x, y, axes=2):
    """Dispatch ``tensordot`` to the dtype implementation of ``x``, then ``y``.

    The first implementation that returns something other than
    ``NotImplemented`` wins.

    Raises:
        UnsupportedOperationError: If both implementations return
            ``NotImplemented``.
    """
    for operand in (x, y):
        result = operand.dtype._ops.tensordot(x, y, axes)
        if result is not NotImplemented:
            return result
    raise UnsupportedOperationError(
        f"Unsupported operand type for tensordot: '{x.dtype}' and '{y.dtype}'"
    )


# indexing.py


Expand Down Expand Up @@ -630,6 +640,12 @@ def permute_dims(x, axes):
)


def repeat(x, repeats, axis=None):
    """Dispatch ``repeat`` to the implementation attached to ``x.dtype``.

    Raises:
        UnsupportedOperationError: If that implementation returns
            ``NotImplemented``.
    """
    result = x.dtype._ops.repeat(x, repeats, axis)
    if result is NotImplemented:
        raise UnsupportedOperationError(f"Unsupported operand type for repeat: '{x.dtype}'")
    return result


def reshape(x, shape, *, copy=None):
if (out := x.dtype._ops.reshape(x, shape, copy=copy)) is not NotImplemented:
return out
Expand Down
43 changes: 42 additions & 1 deletion ndonnx/_opset_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import builtins
import typing
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import TypeVar

import numpy as np
Expand Down Expand Up @@ -440,6 +440,47 @@ def matmul(a: _CoreArray, b: _CoreArray) -> _CoreArray:
return _CoreArray(op.matmul(a.var, b.var))


@eager_propagate
def tensordot(
    a: _CoreArray, b: _CoreArray, axes: int | tuple[Sequence[int], Sequence[int]] = 2
) -> _CoreArray:
    """Tensor contraction of ``a`` and ``b``, lowered to a single ``einsum``.

    Args:
        a: Left operand.
        b: Right operand.
        axes: Either an integer ``N`` (contract the last ``N`` axes of ``a``
            with the first ``N`` axes of ``b``) or a pair
            ``(axes_a, axes_b)`` of equally long axis sequences that are
            contracted pairwise. Negative axes count from the end.

    Returns:
        The contracted array. The non-contracted axes of ``a`` come first,
        followed by the non-contracted axes of ``b``.

    Raises:
        ValueError: If the two axis sequences differ in length, if an axis is
            listed twice for the same operand, or if the einsum equation
            would need more than the 26 subscripts ``a``-``z``.
    """
    # A rank-0 operand has nothing to contract; plain multiplication is used.
    if a.ndim == 0 or b.ndim == 0:
        return _CoreArray(op.mul(a.var, b.var))

    # NOTE: builtins are referenced explicitly, as the original did for
    # ``range`` (presumably because this module shadows some builtin names).
    if isinstance(axes, int):
        axes = (
            [-axes + i for i in builtins.range(axes)],
            [i for i in builtins.range(axes)],
        )

    axes_a, axes_b = axes

    # Normalize negative axes so they can be compared with positional indices.
    axes_a = [(ax + a.ndim) if ax < 0 else ax for ax in axes_a]
    axes_b = [(bx + b.ndim) if bx < 0 else bx for bx in axes_b]

    # Without this check a longer ``axes_a`` would silently sum over the
    # unmatched axes of ``a`` and a longer ``axes_b`` would fail with an
    # unrelated ``IndexError`` further down.
    if builtins.len(axes_a) != builtins.len(axes_b):
        raise ValueError(
            "tensordot requires the same number of axes for both operands, "
            f"got {builtins.len(axes_a)} and {builtins.len(axes_b)}"
        )
    if builtins.len(builtins.set(axes_a)) != builtins.len(axes_a) or builtins.len(
        builtins.set(axes_b)
    ) != builtins.len(axes_b):
        raise ValueError("tensordot axes must not contain duplicates")

    def letter():
        for i in builtins.range(ord("a"), ord("z") + 1):
            yield chr(i)
        raise ValueError("Exceeded available letters for einsum equation")

    letter_gen = letter()

    # Every axis of ``a`` gets a fresh subscript.
    a_letters = [next(letter_gen) for _ in builtins.range(a.ndim)]

    # A contracted axis of ``b`` reuses the subscript of its partner axis in
    # ``a``; every other axis of ``b`` gets a fresh subscript.
    b_letters = [
        a_letters[axes_a[axes_b.index(bx)]] if bx in axes_b else next(letter_gen)
        for bx in builtins.range(b.ndim)
    ]

    # The output keeps the non-contracted subscripts: those of ``a`` first.
    joint_letters = [let for idx, let in enumerate(a_letters) if idx not in axes_a] + [
        let for idx, let in enumerate(b_letters) if idx not in axes_b
    ]

    equation = f"{''.join(a_letters)},{''.join(b_letters)}->{''.join(joint_letters)}"

    return _CoreArray(op.einsum([a.var, b.var], equation=equation))


@eager_propagate
def arg_max(
data: _CoreArray, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0
Expand Down
54 changes: 54 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,3 +1038,57 @@ def test_argmaxmin_unsupported_kernels(func, x):

with pytest.raises(TypeError):
getattr(ndx, func.__name__)(ndx.asarray(x))


# Current array-api tests don't include the case min(a.ndim, b.ndim) != 0
@pytest.mark.parametrize(
    "a, b, axes",
    [
        (
            np.arange(60).reshape(3, 4, 5),
            np.arange(24).reshape(4, 3, 2),
            ([1, 0], [0, 1]),
        ),
        (np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3), 2),
        (np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3), 0),
        (np.arange(60).reshape(4, 5, 3), np.arange(60).reshape(4, 5, 3), 3),
        (np.arange(5).reshape(5), np.arange(5).reshape(5), 1),
    ],
)
def test_tensordot(a, b, axes):
    """``ndx.tensordot`` must agree with ``np.tensordot`` for explicit ``axes``."""
    expected = np.tensordot(a, b, axes=axes)
    lhs, rhs = ndx.asarray(a), ndx.asarray(b)
    actual = ndx.tensordot(lhs, rhs, axes=axes).to_numpy()
    assert_array_equal(expected, actual)


@pytest.mark.parametrize(
    "a, b",
    [
        (np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3)),
    ],
)
def test_tensordot_no_axes(a, b):
    """``ndx.tensordot`` must agree with ``np.tensordot`` when ``axes`` is omitted."""
    expected = np.tensordot(a, b)
    lhs, rhs = ndx.asarray(a), ndx.asarray(b)
    actual = ndx.tensordot(lhs, rhs).to_numpy()
    assert_array_equal(expected, actual)


# The current repeat implementation fails the upstream array-api tests for
# empty tensors because https://github.com/onnx/onnx/pull/6570 has not yet landed in onnx.
@pytest.mark.parametrize("lazy_repeats", [False, True])
@pytest.mark.parametrize(
"a, repeats, axis",
[
(np.arange(60).reshape(3, 4, 5), 3, 0),
(np.arange(60).reshape(3, 4, 5), 3, 1),
(np.arange(60).reshape(3, 4, 5), 3, 2),
(np.arange(60).reshape(3, 4, 5), 3, None),
(np.arange(60).reshape(3, 4, 5), [1, 2, 3], 0),
],
)
def test_repeat(a, repeats, axis, lazy_repeats):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is "lazy" about the lazy_repeats case here? Even when repeats is int, it's immediately wrapped in ndx.asarray(repeats) in your implementation of repeat.

Copy link
Member

@adityagoel4512 adityagoel4512 Jan 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a test that exercises shape inference for repeat, ideally you should be using a lazy array and just building the ONNX model.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea of the test is to test that if we actually supply the argument as a python integer, we wrap it correctly. I rewrote the test to be cleaner.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@neNasko1 please rebase

np_result = np.repeat(a, repeats, axis=axis)
if lazy_repeats or not isinstance(repeats, int):
repeats = ndx.asarray(repeats)
ndx_result = ndx.repeat(ndx.asarray(a), repeats, axis=axis).to_numpy()
assert_array_equal(np_result, ndx_result)
9 changes: 4 additions & 5 deletions xfails.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# segmentation fault: https://github.com/onnx/onnx/pull/6570
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_linalg.py::test_tensordot
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved

array_api_tests/test_constants.py::test_newaxis
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_meshgrid
Expand Down Expand Up @@ -59,16 +63,13 @@ array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm]
array_api_tests/test_has_names.py::test_has_names[linear_algebra-tensordot]
array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot]
array_api_tests/test_has_names.py::test_has_names[manipulation-moveaxis]
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
array_api_tests/test_inspection_functions.py::test_array_namespace_info
array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes
array_api_tests/test_linalg.py::test_matrix_transpose
array_api_tests/test_linalg.py::test_tensordot
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_manipulation_functions.py::test_moveaxis
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_manipulation_functions.py::test_squeeze
array_api_tests/test_manipulation_functions.py::test_tile
array_api_tests/test_manipulation_functions.py::test_unstack
Expand Down Expand Up @@ -113,9 +114,7 @@ array_api_tests/test_signatures.py::test_func_signature[meshgrid]
array_api_tests/test_signatures.py::test_func_signature[minimum]
array_api_tests/test_signatures.py::test_func_signature[moveaxis]
array_api_tests/test_signatures.py::test_func_signature[real]
array_api_tests/test_signatures.py::test_func_signature[repeat]
array_api_tests/test_signatures.py::test_func_signature[signbit]
array_api_tests/test_signatures.py::test_func_signature[tensordot]
array_api_tests/test_signatures.py::test_func_signature[tile]
array_api_tests/test_signatures.py::test_func_signature[unstack]
array_api_tests/test_signatures.py::test_func_signature[vecdot]
Expand Down
Loading