Implement tensordot and repeat (#95)
Signed-off-by: neNasko1 <[email protected]>
Signed-off-by: Atanas Dimitrov <[email protected]>
Co-authored-by: Aditya Goel <[email protected]>
neNasko1 and adityagoel4512 authored Jan 17, 2025
1 parent 8a8b74f commit c919f11
Showing 9 changed files with 198 additions and 9 deletions.
6 changes: 5 additions & 1 deletion ndonnx/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

import importlib.metadata
@@ -127,10 +127,12 @@
trunc,
matmul,
matrix_transpose,
tensordot,
concat,
expand_dims,
flip,
permute_dims,
repeat,
reshape,
roll,
squeeze,
@@ -263,10 +265,12 @@
"trunc",
"matmul",
"matrix_transpose",
"tensordot",
"concat",
"expand_dims",
"flip",
"permute_dims",
"repeat",
"reshape",
"roll",
"squeeze",
8 changes: 7 additions & 1 deletion ndonnx/_core/_interface.py
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
@@ -205,6 +205,9 @@ def matmul(self, x, y) -> ndx.Array:
def matrix_transpose(self, x) -> ndx.Array:
return NotImplemented

def tensordot(self, x, y) -> ndx.Array:
return NotImplemented

# searching.py

def argmax(self, x, axis=None, keepdims=False) -> ndx.Array:
@@ -342,6 +345,9 @@ def permute_dims(self, x, axes) -> ndx.Array:
def reshape(self, x, shape, *, copy=None) -> ndx.Array:
return NotImplemented

def repeat(self, x, repeats, axis=None) -> ndx.Array:
return NotImplemented

def roll(self, x, shift, axis) -> ndx.Array:
return NotImplemented

4 changes: 4 additions & 0 deletions ndonnx/_core/_numericimpl.py
@@ -359,6 +359,10 @@ def matmul(self, x, y):
def matrix_transpose(self, x) -> ndx.Array:
return ndx.permute_dims(x, list(range(x.ndim - 2)) + [x.ndim - 1, x.ndim - 2])

@validate_core
def tensordot(self, x, y, axes):
return _via_i64_f64(lambda x, y: opx.tensordot(x, y, axes), [x, y])

# searching.py

@validate_core
32 changes: 32 additions & 0 deletions ndonnx/_core/_shapeimpl.py
@@ -68,6 +68,38 @@ def expand_dims(self, x, axis):
)
)

def repeat(self, x, repeats, axis=None):
if axis is None:
x = ndx.reshape(x, [-1])
axis = 0

x_shape = ndx.additional.shape(x)

if isinstance(repeats, int):
repeats = ndx.asarray(repeats)

if repeats.ndim == 0:
indices = ndx.broadcast_to(
ndx.arange(x_shape[axis], dtype=ndx.int64),
ndx.concat(
[ndx.expand_dims(repeats, 0), ndx.expand_dims(x_shape[axis], 0)]
),
)
indices = ndx.reshape(ndx.matrix_transpose(indices), [-1])
elif repeats.ndim == 1:
repeats = ndx.broadcast_to(repeats, ndx.reshape(x_shape[axis], [1]))
indices = ndx.searchsorted(
ndx.cumulative_sum(repeats).astype(ndx.int64),
ndx.arange(ndx.sum(repeats), dtype=ndx.int64),
side="right",
)
elif repeats.ndim > 1:
raise ValueError(
f"'repeats' should be either 0 or 1 dimensional, but is instead {repeats.ndim}-dimensional"
)

return ndx.take(x, indices, axis=axis)

def flip(self, x, axis):
if x.ndim == 0:
return x.copy()
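The hunk above builds `repeat` from `take` plus explicitly constructed gather indices: a scalar `repeats` broadcasts `arange(n)` to a `(repeats, n)` grid and flattens its transpose, while a 1-D `repeats` maps every output position back to its source index with `searchsorted` over the cumulative sum. A minimal NumPy sketch of that index construction (illustration only, not part of this diff):

```python
import numpy as np

def repeat_indices(n: int, repeats) -> np.ndarray:
    """Gather indices that emulate np.repeat along one axis of length n."""
    repeats = np.asarray(repeats)
    if repeats.ndim == 0:
        # Scalar repeats: tile 0..n-1 into a (repeats, n) grid, then transpose and
        # flatten so each index appears `repeats` times in a row.
        grid = np.broadcast_to(np.arange(n, dtype=np.int64), (int(repeats), n))
        return grid.T.reshape(-1)
    # Per-element repeats: searchsorted over the cumulative sum maps output
    # position k to the input index that owns it.
    csum = np.cumsum(repeats).astype(np.int64)
    return np.searchsorted(csum, np.arange(csum[-1], dtype=np.int64), side="right")

x = np.arange(12).reshape(3, 4)
idx = repeat_indices(x.shape[0], np.array([1, 0, 2]))
assert np.array_equal(np.take(x, idx, axis=0), np.repeat(x, [1, 0, 2], axis=0))
```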
18 changes: 17 additions & 1 deletion ndonnx/_funcs.py
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
@@ -550,6 +550,16 @@ def matrix_transpose(x):
return _unary(x.dtype._ops.matrix_transpose, x)


def tensordot(x, y, /, *, axes=2):
if (out := x.dtype._ops.tensordot(x, y, axes)) is not NotImplemented:
return out
if (out := y.dtype._ops.tensordot(x, y, axes)) is not NotImplemented:
return out
raise UnsupportedOperationError(
f"Unsupported operand type for tensordot: '{x.dtype}' and '{y.dtype}'"
)


# indexing.py


@@ -630,6 +640,12 @@ def permute_dims(x, axes):
)


def repeat(x, repeats, /, *, axis=None):
if (out := x.dtype._ops.repeat(x, repeats, axis)) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for repeat: '{x.dtype}'")


def reshape(x, shape, *, copy=None):
if (out := x.dtype._ops.reshape(x, shape, copy=copy)) is not NotImplemented:
return out
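Both new public functions follow the existing dispatch pattern in `_funcs.py`: try the operation via the operands' `dtype._ops` tables and raise `UnsupportedOperationError` once every candidate returns `NotImplemented`. Expected usage after this change, mirroring shapes exercised in the new tests (a small, untested sketch):

```python
import numpy as np
import ndonnx as ndx

a = ndx.asarray(np.arange(60).reshape(3, 4, 5))
b = ndx.asarray(np.arange(24).reshape(4, 3, 2))

# Contract a's axes (1, 0) against b's axes (0, 1), as in the first test case.
c = ndx.tensordot(a, b, axes=([1, 0], [0, 1]))
print(c.to_numpy().shape)  # (5, 2)

# Repeat every element three times along the last axis.
r = ndx.repeat(a, 3, axis=2)
print(r.to_numpy().shape)  # (3, 4, 15)
```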
45 changes: 44 additions & 1 deletion ndonnx/_opset_extensions.py
@@ -5,7 +5,7 @@

import builtins
import typing
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import TypeVar

import numpy as np
@@ -440,6 +440,49 @@ def matmul(a: _CoreArray, b: _CoreArray) -> _CoreArray:
return _CoreArray(op.matmul(a.var, b.var))


@eager_propagate
def tensordot(
a: _CoreArray, b: _CoreArray, axes: int | tuple[Sequence[int], Sequence[int]] = 2
) -> _CoreArray:
def letter():
for i in builtins.range(ord("a"), ord("z") + 1):
yield chr(i)
raise ValueError(
"Exceeded available letters for einsum equation in the implementation of 'tensordot': this means that the number of dimensions of 'a' and 'b' are too large"
)

letter_gen = letter()

if a.ndim == 0 or b.ndim == 0:
return _CoreArray(op.mul(a.var, b.var))

if isinstance(axes, int):
axes = (
[-axes + i for i in builtins.range(axes)],
[i for i in builtins.range(axes)],
)

axes_a, axes_b = axes

axes_a = [(ax + a.ndim) if ax < 0 else ax for ax in axes_a]
axes_b = [(bx + b.ndim) if bx < 0 else bx for bx in axes_b]

a_letters = [next(letter_gen) for _ in builtins.range(a.ndim)]

b_letters = [
a_letters[axes_a[axes_b.index(bx)]] if bx in axes_b else next(letter_gen)
for bx in builtins.range(b.ndim)
]

joint_letters = [let for idx, let in enumerate(a_letters) if idx not in axes_a] + [
let for idx, let in enumerate(b_letters) if idx not in axes_b
]

equation = f"{''.join(a_letters)},{''.join(b_letters)}->{''.join(joint_letters)}"

return _CoreArray(op.einsum([a.var, b.var], equation=equation))


@eager_propagate
def arg_max(
data: _CoreArray, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0
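The kernel lowers `tensordot` to a single ONNX `Einsum`: every dimension of `a` gets a fresh letter, each contracted dimension of `b` reuses the letter of its partner in `a`, and the output keeps the non-contracted letters of both operands in order. A hedged sketch of the same equation-building logic in plain NumPy (for illustration only; the in-tree version operates on `_CoreArray` via `op.einsum`):

```python
from string import ascii_lowercase
import numpy as np

def tensordot_equation(a_ndim: int, b_ndim: int, axes) -> str:
    """Rebuild the einsum equation the way the new tensordot kernel does."""
    if isinstance(axes, int):
        # Integer axes: contract the last `axes` dims of a with the first `axes` of b.
        axes = (list(range(a_ndim - axes, a_ndim)), list(range(axes)))
    axes_a = [ax % a_ndim for ax in axes[0]]  # normalize negative axes
    axes_b = [bx % b_ndim for bx in axes[1]]

    letters = iter(ascii_lowercase)
    a_letters = [next(letters) for _ in range(a_ndim)]
    # Contracted axes of b reuse the letter of their paired axis in a.
    b_letters = [
        a_letters[axes_a[axes_b.index(bx)]] if bx in axes_b else next(letters)
        for bx in range(b_ndim)
    ]
    out = [let for i, let in enumerate(a_letters) if i not in axes_a] + [
        let for i, let in enumerate(b_letters) if i not in axes_b
    ]
    return f"{''.join(a_letters)},{''.join(b_letters)}->{''.join(out)}"

a = np.arange(60).reshape(3, 4, 5)
b = np.arange(24).reshape(4, 3, 2)
eq = tensordot_equation(a.ndim, b.ndim, ([1, 0], [0, 1]))
print(eq)  # abc,bad->cd
assert np.array_equal(np.einsum(eq, a, b), np.tensordot(a, b, axes=([1, 0], [0, 1])))
```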
6 changes: 6 additions & 0 deletions skips.txt
@@ -2,3 +2,9 @@

array_api_tests/test_manipulation_functions.py::test_roll
array_api_tests/test_data_type_functions.py::test_broadcast_arrays

# segmentation fault: https://github.com/onnx/onnx/pull/6570
array_api_tests/test_manipulation_functions.py::test_repeat

# segmentation fault: https://github.com/microsoft/onnxruntime/pull/23379
array_api_tests/test_linalg.py::test_tensordot
83 changes: 83 additions & 0 deletions tests/test_core.py
@@ -1045,6 +1045,89 @@ def test_argmaxmin_unsupported_kernels(func, x):
getattr(ndx, func.__name__)(ndx.asarray(x))


@pytest.mark.parametrize(
"a, b, axes",
[
(
np.arange(60).reshape(3, 4, 5),
np.arange(24).reshape(4, 3, 2),
([1, 0], [0, 1]),
),
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3), 2),
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3), 0),
(np.arange(60).reshape(4, 5, 3), np.arange(60).reshape(4, 5, 3), 3),
(np.arange(5).reshape(5), np.arange(5).reshape(5), 1),
(np.arange(36).reshape(6, 6), np.arange(36).reshape(6, 6), 1),
(np.arange(24).reshape(3, 2, 4), np.arange(24).reshape(4, 2, 3), 1),
(np.arange(35).reshape(5, 7), np.arange(35).reshape(7, 5), 1),
(np.arange(35).reshape(7, 5), np.arange(35).reshape(7, 5), 2),
(np.arange(48).reshape(4, 3, 4), np.arange(48).reshape(4, 4, 3), 0),
(
np.arange(32).reshape(4, 4, 2),
np.arange(32).reshape(2, 4, 4),
([2, 0], [0, 1]),
),
(np.arange(30).reshape(3, 10), np.arange(20).reshape(10, 2), ([1], [0])),
],
)
def test_tensordot(a, b, axes):
np_result = np.tensordot(a, b, axes=axes)
ndx_result = ndx.tensordot(ndx.asarray(a), ndx.asarray(b), axes=axes).to_numpy()
assert_array_equal(np_result, ndx_result)


@pytest.mark.parametrize(
"a, b",
[
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape(4, 5, 3)),
],
)
def test_tensordot_no_axes(a, b):
np_result = np.tensordot(a, b)
ndx_result = ndx.tensordot(ndx.asarray(a), ndx.asarray(b)).to_numpy()
assert_array_equal(np_result, ndx_result)


# The current repeat implementation fails the upstream array-api tests for empty
# tensors because https://github.com/onnx/onnx/pull/6570 has not yet landed in onnx.
@pytest.mark.parametrize(
"a, repeats, axis",
[
(np.arange(60).reshape(3, 4, 5), 3, 0),
(np.arange(60).reshape(3, 4, 5), 3, 1),
(np.arange(60).reshape(3, 4, 5), 3, 2),
(np.arange(60).reshape(3, 4, 5), 3, None),
(np.arange(60).reshape(3, 4, 5), np.array(3), 0),
(np.arange(60).reshape(3, 4, 5), np.array(3), 1),
(np.arange(60).reshape(3, 4, 5), np.array(3), 2),
(np.arange(60).reshape(3, 4, 5), np.array(3), None),
(np.arange(60).reshape(3, 4, 5), np.arange(3), 0),
(np.arange(60).reshape(3, 4, 5), np.arange(4), 1),
(np.arange(60).reshape(3, 4, 5), np.arange(5), 2),
(np.arange(60).reshape(3, 4, 5), np.arange(60), None),
],
)
def test_repeat(a, repeats, axis):
np_result = np.repeat(a, repeats, axis=axis)
if not isinstance(repeats, int):
repeats = ndx.asarray(repeats)
ndx_result = ndx.repeat(ndx.asarray(a), repeats, axis=axis).to_numpy()
assert_array_equal(np_result, ndx_result)


@pytest.mark.parametrize(
"a, repeats, axis",
[
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape([3, 4, 5]), 0),
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape([3, 20]), None),
(np.arange(60).reshape(3, 4, 5), np.arange(60).reshape([2, 3, 2, 5]), None),
],
)
def test_repeat_raises(a, repeats, axis):
with pytest.raises(ValueError):
ndx.repeat(ndx.asarray(a), repeats, axis=axis).to_numpy()


@pytest.mark.parametrize(
"x, index",
[
5 changes: 0 additions & 5 deletions xfails.txt
@@ -63,16 +63,13 @@ array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm]
array_api_tests/test_has_names.py::test_has_names[linear_algebra-tensordot]
array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot]
array_api_tests/test_has_names.py::test_has_names[manipulation-moveaxis]
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
array_api_tests/test_inspection_functions.py::test_array_namespace_info
array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes
array_api_tests/test_linalg.py::test_matrix_transpose
array_api_tests/test_linalg.py::test_tensordot
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_manipulation_functions.py::test_moveaxis
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_manipulation_functions.py::test_squeeze
array_api_tests/test_manipulation_functions.py::test_tile
array_api_tests/test_manipulation_functions.py::test_unstack
@@ -117,9 +114,7 @@ array_api_tests/test_signatures.py::test_func_signature[meshgrid]
array_api_tests/test_signatures.py::test_func_signature[minimum]
array_api_tests/test_signatures.py::test_func_signature[moveaxis]
array_api_tests/test_signatures.py::test_func_signature[real]
array_api_tests/test_signatures.py::test_func_signature[repeat]
array_api_tests/test_signatures.py::test_func_signature[signbit]
array_api_tests/test_signatures.py::test_func_signature[tensordot]
array_api_tests/test_signatures.py::test_func_signature[tile]
array_api_tests/test_signatures.py::test_func_signature[unstack]
array_api_tests/test_signatures.py::test_func_signature[vecdot]