diff --git a/ndonnx/__init__.py b/ndonnx/__init__.py index 9587f99..5de7639 100644 --- a/ndonnx/__init__.py +++ b/ndonnx/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause import importlib.metadata @@ -127,10 +127,12 @@ trunc, matmul, matrix_transpose, + tensordot, concat, expand_dims, flip, permute_dims, + repeat, reshape, roll, squeeze, @@ -263,10 +265,12 @@ "trunc", "matmul", "matrix_transpose", + "tensordot", "concat", "expand_dims", "flip", "permute_dims", + "repeat", "reshape", "roll", "squeeze", diff --git a/ndonnx/_core/_interface.py b/ndonnx/_core/_interface.py index 5340f4f..23bb388 100644 --- a/ndonnx/_core/_interface.py +++ b/ndonnx/_core/_interface.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -205,6 +205,9 @@ def matmul(self, x, y) -> ndx.Array: def matrix_transpose(self, x) -> ndx.Array: return NotImplemented + def tensordot(self, x, y) -> ndx.Array: + return NotImplemented + # searching.py def argmax(self, x, axis=None, keepdims=False) -> ndx.Array: @@ -342,6 +345,9 @@ def permute_dims(self, x, axes) -> ndx.Array: def reshape(self, x, shape, *, copy=None) -> ndx.Array: return NotImplemented + def repeat(self, x, repeats, axis=None) -> ndx.Array: + return NotImplemented + def roll(self, x, shift, axis) -> ndx.Array: return NotImplemented diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index f67c8ec..42e5e39 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -359,6 +359,10 @@ def matmul(self, x, y): def matrix_transpose(self, x) -> ndx.Array: return ndx.permute_dims(x, list(range(x.ndim - 2)) + [x.ndim - 1, x.ndim - 2]) + @validate_core + def tensordot(self, x, y, axes): + return _via_i64_f64(lambda x, y: opx.tensordot(x, y, axes), [x, y]) + # searching.py @validate_core diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index ff8142e..88a0285 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -68,6 +68,38 @@ def expand_dims(self, x, axis): ) ) + def repeat(self, x, repeats, axis=None): + if axis is None: + x = ndx.reshape(x, [-1]) + axis = 0 + + x_shape = ndx.additional.shape(x) + + if isinstance(repeats, int): + repeats = ndx.asarray(repeats) + + if repeats.ndim == 0: + indices = ndx.broadcast_to( + ndx.arange(x_shape[axis], dtype=ndx.int64), + ndx.concat( + [ndx.expand_dims(repeats, 0), ndx.expand_dims(x_shape[axis], 0)] + ), + ) + indices = ndx.reshape(ndx.matrix_transpose(indices), [-1]) + elif repeats.ndim == 1: + repeats = ndx.broadcast_to(repeats, ndx.reshape(x_shape[axis], [1])) + indices = ndx.searchsorted( + ndx.cumulative_sum(repeats).astype(ndx.int64), + ndx.arange(ndx.sum(repeats), dtype=ndx.int64), + side="right", + ) + elif repeats.ndim > 1: + raise ValueError( + f"'repeats' should be either 0 or 1 dimensional, but is instead {repeats.ndim}-dimensional" + ) + + return ndx.take(x, indices, axis=axis) + def flip(self, x, axis): if x.ndim == 0: return x.copy() diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index 3021b2d..e3b9e17 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -550,6 +550,16 @@ def matrix_transpose(x): return _unary(x.dtype._ops.matrix_transpose, x) +def tensordot(x, y, /, *, axes=2): + if (out := x.dtype._ops.tensordot(x, y, axes)) is not NotImplemented: + return out + if (out := y.dtype._ops.tensordot(x, y, axes)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for tensordot: '{x.dtype}' and '{y.dtype}'" + ) + + # indexing.py @@ -630,6 +640,12 @@ def permute_dims(x, axes): ) +def repeat(x, repeats, /, *, axis=None): + if (out := x.dtype._ops.repeat(x, repeats, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError(f"Unsupported operand type for repeat: '{x.dtype}'") + + def reshape(x, shape, *, copy=None): if (out := x.dtype._ops.reshape(x, shape, copy=copy)) is not NotImplemented: return out diff --git a/ndonnx/_opset_extensions.py b/ndonnx/_opset_extensions.py index 5056d0a..df6e8cc 100644 --- a/ndonnx/_opset_extensions.py +++ b/ndonnx/_opset_extensions.py @@ -5,7 +5,7 @@ import builtins import typing -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import TypeVar import numpy as np @@ -440,6 +440,49 @@ def matmul(a: _CoreArray, b: _CoreArray) -> _CoreArray: return _CoreArray(op.matmul(a.var, b.var)) +@eager_propagate +def tensordot( + a: _CoreArray, b: _CoreArray, axes: int | tuple[Sequence[int], Sequence[int]] = 2 +) -> _CoreArray: + def letter(): + for i in builtins.range(ord("a"), ord("z") + 1): + yield chr(i) + raise ValueError( + "Exceeded available letters for einsum equation in the implementation of 'tensordot': this means that the number of dimensions of 'a' and 'b' are too large" + ) + + letter_gen = letter() + + if a.ndim == 0 or b.ndim == 0: + return _CoreArray(op.mul(a.var, b.var)) + + if isinstance(axes, int): + axes = ( + [-axes + i for i in builtins.range(axes)], + [i for i in builtins.range(axes)], + ) + + axes_a, axes_b = axes + + axes_a = [(ax + a.ndim) if ax < 0 else ax for ax in axes_a] + axes_b = [(bx + b.ndim) if bx < 0 else bx for bx in axes_b] + + a_letters = [next(letter_gen) for _ in builtins.range(a.ndim)] + + b_letters = [ + a_letters[axes_a[axes_b.index(bx)]] if bx in axes_b else next(letter_gen) + for bx in builtins.range(b.ndim) + ] + + joint_letters = [let for idx, let in enumerate(a_letters) if idx not in axes_a] + [ + let for idx, let in enumerate(b_letters) if idx not in axes_b + ] + + equation = f"{''.join(a_letters)},{''.join(b_letters)}->{''.join(joint_letters)}" + + return _CoreArray(op.einsum([a.var, b.var], equation=equation)) + + @eager_propagate def arg_max( data: _CoreArray, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0 diff --git a/skips.txt b/skips.txt index f372e85..f80bc43 100644 --- a/skips.txt +++ b/skips.txt @@ -2,3 +2,9 @@ array_api_tests/test_manipulation_functions.py::test_roll array_api_tests/test_data_type_functions.py::test_broadcast_arrays + +# segmentation fault: https://github.com/onnx/onnx/pull/6570 +array_api_tests/test_manipulation_functions.py::test_repeat + +# segmentation fault: https://github.com/microsoft/onnxruntime/pull/23379 +array_api_tests/test_linalg.py::test_tensordot diff --git a/tests/test_core.py b/tests/test_core.py index 81a9547..2504061 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1045,6 +1045,89 @@ def test_argmaxmin_unsupported_kernels(func, x): getattr(ndx, func.__name__)(ndx.asarray(x)) +@pytest.mark.parametrize( + "a, b, axes", + [ + ( + np.arange(60).reshape(3, 4, 5), + np.arange(24).reshape(4, 3, 2), + ([1, 0], [0, 1]), + ), + (np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3), 2), + (np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3), 0), + (np.arange(60).reshape(4, 5, 3), np.arange(60).reshape(4, 5, 3), 3), + (np.arange(5).reshape(5), np.arange(5).reshape(5), 1), + (np.arange(36).reshape(6, 6), np.arange(36).reshape(6, 6), 1), + (np.arange(24).reshape(3, 2, 4), np.arange(24).reshape(4, 2, 3), 1), + (np.arange(35).reshape(5, 7), np.arange(35).reshape(7, 5), 1), + (np.arange(35).reshape(7, 5), np.arange(35).reshape(7, 5), 2), + (np.arange(48).reshape(4, 3, 4), np.arange(48).reshape(4, 4, 3), 0), + ( + np.arange(32).reshape(4, 4, 2), + np.arange(32).reshape(2, 4, 4), + ([2, 0], [0, 1]), + ), + (np.arange(30).reshape(3, 10), np.arange(20).reshape(10, 2), ([1], [0])), + ], +) +def test_tensordot(a, b, axes): + np_result = np.tensordot(a, b, axes=axes) + ndx_result = ndx.tensordot(ndx.asarray(a), ndx.asarray(b), axes=axes).to_numpy() + assert_array_equal(np_result, ndx_result) + + +@pytest.mark.parametrize( + "a, b", + [ + (np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3)), + ], +) +def test_tensordot_no_axes(a, b): + np_result = np.tensordot(a, b) + ndx_result = ndx.tensordot(ndx.asarray(a), ndx.asarray(b)).to_numpy() + assert_array_equal(np_result, ndx_result) + + +# Current repeat does not work on the upstream arrayapi tests in the case +# of an empty tensor as https://github.com/onnx/onnx/pull/6570 has not landed in onnx +@pytest.mark.parametrize( + "a, repeats, axis", + [ + (np.arange(60).reshape(3, 4, 5), 3, 0), + (np.arange(60).reshape(3, 4, 5), 3, 1), + (np.arange(60).reshape(3, 4, 5), 3, 2), + (np.arange(60).reshape(3, 4, 5), 3, None), + (np.arange(60).reshape(3, 4, 5), np.array(3), 0), + (np.arange(60).reshape(3, 4, 5), np.array(3), 1), + (np.arange(60).reshape(3, 4, 5), np.array(3), 2), + (np.arange(60).reshape(3, 4, 5), np.array(3), None), + (np.arange(60).reshape(3, 4, 5), np.arange(3), 0), + (np.arange(60).reshape(3, 4, 5), np.arange(4), 1), + (np.arange(60).reshape(3, 4, 5), np.arange(5), 2), + (np.arange(60).reshape(3, 4, 5), np.arange(60), None), + ], +) +def test_repeat(a, repeats, axis): + np_result = np.repeat(a, repeats, axis=axis) + if not isinstance(repeats, int): + repeats = ndx.asarray(repeats) + ndx_result = ndx.repeat(ndx.asarray(a), repeats, axis=axis).to_numpy() + assert_array_equal(np_result, ndx_result) + + +@pytest.mark.parametrize( + "a, repeats, axis", + [ + (np.arange(60).reshape(3, 4, 5), np.arange(60).reshape([3, 4, 5]), 0), + (np.arange(60).reshape(3, 4, 5), np.arange(60).reshape([3, 20]), None), + (np.arange(60).reshape(3, 4, 5), np.arange(60).reshape([2, 3, 2, 5]), None), + ], +) +def test_repeat_raises(a, repeats, axis): + with pytest.raises(ValueError): + ndx.repeat(ndx.asarray(a), repeats, axis=axis).to_numpy() + + @pytest.mark.parametrize( "x, index", [ diff --git a/xfails.txt b/xfails.txt index 0f06fd8..fa73a81 100644 --- a/xfails.txt +++ b/xfails.txt @@ -63,16 +63,13 @@ array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm] array_api_tests/test_has_names.py::test_has_names[linear_algebra-tensordot] array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot] array_api_tests/test_has_names.py::test_has_names[manipulation-moveaxis] -array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] array_api_tests/test_has_names.py::test_has_names[manipulation-tile] array_api_tests/test_has_names.py::test_has_names[manipulation-unstack] array_api_tests/test_inspection_functions.py::test_array_namespace_info array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes array_api_tests/test_linalg.py::test_matrix_transpose -array_api_tests/test_linalg.py::test_tensordot array_api_tests/test_linalg.py::test_vecdot array_api_tests/test_manipulation_functions.py::test_moveaxis -array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_manipulation_functions.py::test_squeeze array_api_tests/test_manipulation_functions.py::test_tile array_api_tests/test_manipulation_functions.py::test_unstack @@ -117,9 +114,7 @@ array_api_tests/test_signatures.py::test_func_signature[meshgrid] array_api_tests/test_signatures.py::test_func_signature[minimum] array_api_tests/test_signatures.py::test_func_signature[moveaxis] array_api_tests/test_signatures.py::test_func_signature[real] -array_api_tests/test_signatures.py::test_func_signature[repeat] array_api_tests/test_signatures.py::test_func_signature[signbit] -array_api_tests/test_signatures.py::test_func_signature[tensordot] array_api_tests/test_signatures.py::test_func_signature[tile] array_api_tests/test_signatures.py::test_func_signature[unstack] array_api_tests/test_signatures.py::test_func_signature[vecdot]