Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tensordot and repeat #95

Merged
merged 12 commits into from
Jan 17, 2025
6 changes: 5 additions & 1 deletion ndonnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

import importlib.metadata
Expand Down Expand Up @@ -127,10 +127,12 @@
trunc,
matmul,
matrix_transpose,
tensordot,
concat,
expand_dims,
flip,
permute_dims,
repeat,
reshape,
roll,
squeeze,
Expand Down Expand Up @@ -263,10 +265,12 @@
"trunc",
"matmul",
"matrix_transpose",
"tensordot",
"concat",
"expand_dims",
"flip",
"permute_dims",
"repeat",
"reshape",
"roll",
"squeeze",
Expand Down
8 changes: 7 additions & 1 deletion ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -205,6 +205,9 @@ def matmul(self, x, y) -> ndx.Array:
def matrix_transpose(self, x) -> ndx.Array:
return NotImplemented

def tensordot(self, x, y) -> ndx.Array:
neNasko1 marked this conversation as resolved.
Show resolved Hide resolved
return NotImplemented

# searching.py

def argmax(self, x, axis=None, keepdims=False) -> ndx.Array:
Expand Down Expand Up @@ -342,6 +345,9 @@ def permute_dims(self, x, axes) -> ndx.Array:
def reshape(self, x, shape, *, copy=None) -> ndx.Array:
return NotImplemented

def repeat(self, x, repeats, axis=None) -> ndx.Array:
return NotImplemented

def roll(self, x, shift, axis) -> ndx.Array:
return NotImplemented

Expand Down
4 changes: 4 additions & 0 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ def matmul(self, x, y):
def matrix_transpose(self, x) -> ndx.Array:
return ndx.permute_dims(x, list(range(x.ndim - 2)) + [x.ndim - 1, x.ndim - 2])

@validate_core
def tensordot(self, x, y, axes):
return _via_i64_f64(lambda x, y: opx.tensordot(x, y, axes), [x, y])

# searching.py

@validate_core
Expand Down
32 changes: 32 additions & 0 deletions ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,38 @@ def expand_dims(self, x, axis):
)
)

def repeat(self, x, repeats, axis=None):
if axis is None:
x = ndx.reshape(x, [-1])
axis = 0

x_shape = ndx.additional.shape(x)

if isinstance(repeats, int):
repeats = ndx.asarray(repeats)

if repeats.ndim == 0:
indices = ndx.broadcast_to(
ndx.arange(x_shape[axis], dtype=ndx.int64),
ndx.concat(
[ndx.expand_dims(repeats, 0), ndx.expand_dims(x_shape[axis], 0)]
),
)
indices = ndx.reshape(ndx.matrix_transpose(indices), [-1])
elif repeats.ndim == 1:
repeats = ndx.broadcast_to(repeats, ndx.reshape(x_shape[axis], [1]))
indices = ndx.searchsorted(
ndx.cumulative_sum(repeats).astype(ndx.int64),
ndx.arange(ndx.sum(repeats), dtype=ndx.int64),
side="right",
)
elif repeats.ndim > 1:
raise ValueError(
f"'repeats' should be either 0 or 1 dimensional, but is instead {repeats.ndim}-dimensional"
)

return ndx.take(x, indices, axis=axis)
neNasko1 marked this conversation as resolved.
Show resolved Hide resolved

def flip(self, x, axis):
if x.ndim == 0:
return x.copy()
Expand Down
18 changes: 17 additions & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -550,6 +550,16 @@ def matrix_transpose(x):
return _unary(x.dtype._ops.matrix_transpose, x)


def tensordot(x, y, /, *, axes=2):
if (out := x.dtype._ops.tensordot(x, y, axes)) is not NotImplemented:
return out
if (out := y.dtype._ops.tensordot(x, y, axes)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for tensordot: '{x.dtype}' and '{y.dtype}'"
)


# indexing.py


Expand Down Expand Up @@ -630,6 +640,12 @@ def permute_dims(x, axes):
)


def repeat(x, repeats, /, *, axis=None):
if (out := x.dtype._ops.repeat(x, repeats, axis)) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for repeat: '{x.dtype}'")


def reshape(x, shape, *, copy=None):
if (out := x.dtype._ops.reshape(x, shape, copy=copy)) is not NotImplemented:
return out
Expand Down
45 changes: 44 additions & 1 deletion ndonnx/_opset_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import builtins
import typing
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import TypeVar

import numpy as np
Expand Down Expand Up @@ -440,6 +440,49 @@ def matmul(a: _CoreArray, b: _CoreArray) -> _CoreArray:
return _CoreArray(op.matmul(a.var, b.var))


@eager_propagate
def tensordot(
a: _CoreArray, b: _CoreArray, axes: int | tuple[Sequence[int], Sequence[int]] = 2
) -> _CoreArray:
def letter():
for i in builtins.range(ord("a"), ord("z") + 1):
yield chr(i)
raise ValueError(
"Exceeded available letters for einsum equation in the implementation of 'tensordot': this means that the number of dimensions of 'a' and 'b' are too large"
)

letter_gen = letter()

if a.ndim == 0 or b.ndim == 0:
return _CoreArray(op.mul(a.var, b.var))

if isinstance(axes, int):
axes = (
[-axes + i for i in builtins.range(axes)],
[i for i in builtins.range(axes)],
)

axes_a, axes_b = axes

axes_a = [(ax + a.ndim) if ax < 0 else ax for ax in axes_a]
axes_b = [(bx + b.ndim) if bx < 0 else bx for bx in axes_b]

a_letters = [next(letter_gen) for _ in builtins.range(a.ndim)]

b_letters = [
a_letters[axes_a[axes_b.index(bx)]] if bx in axes_b else next(letter_gen)
for bx in builtins.range(b.ndim)
]

joint_letters = [let for idx, let in enumerate(a_letters) if idx not in axes_a] + [
let for idx, let in enumerate(b_letters) if idx not in axes_b
]

equation = f"{''.join(a_letters)},{''.join(b_letters)}->{''.join(joint_letters)}"

return _CoreArray(op.einsum([a.var, b.var], equation=equation))


@eager_propagate
def arg_max(
data: _CoreArray, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0
Expand Down
6 changes: 6 additions & 0 deletions skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,9 @@

array_api_tests/test_manipulation_functions.py::test_roll
array_api_tests/test_data_type_functions.py::test_broadcast_arrays

# segmentation fault: https://github.com/onnx/onnx/pull/6570
array_api_tests/test_manipulation_functions.py::test_repeat

# segmentation fault: due to onnxruntime not handling einsum with dim 0
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved
array_api_tests/test_linalg.py::test_tensordot
82 changes: 82 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,88 @@ def test_argmaxmin_unsupported_kernels(func, x):
getattr(ndx, func.__name__)(ndx.asarray(x))


@pytest.mark.parametrize(
"a, b, axes",
[
(
np.arange(60).reshape(3, 4, 5),
np.arange(24).reshape(4, 3, 2),
([1, 0], [0, 1]),
),
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3), 2),
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3), 0),
(np.arange(60).reshape(4, 5, 3), np.arange(60).reshape(4, 5, 3), 3),
(np.arange(5).reshape(5), np.arange(5).reshape(5), 1),
(np.arange(36).reshape(6, 6), np.arange(36).reshape(6, 6), 1),
(np.arange(24).reshape(3, 2, 4), np.arange(24).reshape(4, 2, 3), 1),
(np.arange(35).reshape(5, 7), np.arange(35).reshape(7, 5), 1),
(np.arange(35).reshape(7, 5), np.arange(35).reshape(7, 5), 2),
(np.arange(48).reshape(4, 3, 4), np.arange(48).reshape(4, 4, 3), 0),
(
np.arange(32).reshape(4, 4, 2),
np.arange(32).reshape(2, 4, 4),
([2, 0], [0, 1]),
),
(np.arange(30).reshape(3, 10), np.arange(20).reshape(10, 2), ([1], [0])),
],
)
def test_tensordot(a, b, axes):
np_result = np.tensordot(a, b, axes=axes)
ndx_result = ndx.tensordot(ndx.asarray(a), ndx.asarray(b), axes=axes).to_numpy()
assert_array_equal(np_result, ndx_result)


@pytest.mark.parametrize(
"a, b",
[
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3)),
],
)
def test_tensordot_no_axes(a, b):
np_result = np.tensordot(a, b)
ndx_result = ndx.tensordot(ndx.asarray(a), ndx.asarray(b)).to_numpy()
assert_array_equal(np_result, ndx_result)


# Current repeat does not work on the upstream arrayapi tests in the case
# of an empty tensor as https://github.com/onnx/onnx/pull/6570 has not landed in onnx
@pytest.mark.parametrize(
"a, repeats, axis",
[
(np.arange(60).reshape(3, 4, 5), 3, 0),
(np.arange(60).reshape(3, 4, 5), 3, 1),
(np.arange(60).reshape(3, 4, 5), 3, 2),
(np.arange(60).reshape(3, 4, 5), 3, None),
(np.arange(60).reshape(3, 4, 5), np.array(3), 0),
(np.arange(60).reshape(3, 4, 5), np.array(3), 1),
(np.arange(60).reshape(3, 4, 5), np.array(3), 2),
(np.arange(60).reshape(3, 4, 5), np.array(3), None),
(np.arange(60).reshape(3, 4, 5), np.arange(3), 0),
(np.arange(60).reshape(3, 4, 5), np.arange(4), 1),
(np.arange(60).reshape(3, 4, 5), np.arange(5), 2),
(np.arange(60).reshape(3, 4, 5), np.arange(60), None),
],
)
def test_repeat(a, repeats, axis):
np_result = np.repeat(a, repeats, axis=axis)
if not isinstance(repeats, int):
repeats = ndx.asarray(repeats)
ndx_result = ndx.repeat(ndx.asarray(a), repeats, axis=axis).to_numpy()
assert_array_equal(np_result, ndx_result)


@pytest.mark.parametrize(
"a, repeats, axis",
[
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape([3, 4, 5]), 0),
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape([3, 20]), None),
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape([2, 3, 2, 5]), None),
],
)
def test_repeat_raises(a, repeats, axis):
with pytest.raises(ValueError):
ndx.repeat(ndx.asarray(a), repeats, axis=axis).to_numpy()

adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize(
"x, index",
[
Expand Down
5 changes: 0 additions & 5 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,13 @@ array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm]
array_api_tests/test_has_names.py::test_has_names[linear_algebra-tensordot]
array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot]
array_api_tests/test_has_names.py::test_has_names[manipulation-moveaxis]
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
array_api_tests/test_inspection_functions.py::test_array_namespace_info
array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes
array_api_tests/test_linalg.py::test_matrix_transpose
array_api_tests/test_linalg.py::test_tensordot
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_manipulation_functions.py::test_moveaxis
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_manipulation_functions.py::test_squeeze
array_api_tests/test_manipulation_functions.py::test_tile
array_api_tests/test_manipulation_functions.py::test_unstack
Expand Down Expand Up @@ -117,9 +114,7 @@ array_api_tests/test_signatures.py::test_func_signature[meshgrid]
array_api_tests/test_signatures.py::test_func_signature[minimum]
array_api_tests/test_signatures.py::test_func_signature[moveaxis]
array_api_tests/test_signatures.py::test_func_signature[real]
array_api_tests/test_signatures.py::test_func_signature[repeat]
array_api_tests/test_signatures.py::test_func_signature[signbit]
array_api_tests/test_signatures.py::test_func_signature[tensordot]
array_api_tests/test_signatures.py::test_func_signature[tile]
array_api_tests/test_signatures.py::test_func_signature[unstack]
array_api_tests/test_signatures.py::test_func_signature[vecdot]
Expand Down
Loading