Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tensordot and repeat #95

Merged
merged 12 commits into from
Jan 17, 2025
40 changes: 40 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import numpy as np
import pytest
import spox.opset.ai.onnx.v19 as op
from hypothesis import given
cbourjau marked this conversation as resolved.
Show resolved Hide resolved
from hypothesis import strategies as st
from hypothesis.extra.numpy import array_shapes

import ndonnx as ndx
import ndonnx.additional as nda
Expand Down Expand Up @@ -1073,6 +1076,43 @@ def test_tensordot_no_axes(a, b):
assert_array_equal(np_result, ndx_result)


def generate_tensordot_cases():
shape1 = array_shapes(min_dims=1, max_dims=4, min_side=1, max_side=5)
shape2 = array_shapes(min_dims=1, max_dims=4, min_side=1, max_side=5)

def compatible_shapes_and_axes(shape1, shape2):
shape1 = list(shape1)
shape2 = list(shape2)

a = np.random.randint(0, 100, size=shape1)
b = np.random.randint(0, 100, size=shape2)

open = list(np.random.permutation(np.arange(len(shape2))))

axes1, axes2 = [], []

for i, d1 in enumerate(shape1):
for ind, j in enumerate(open):
if d1 != shape2[j]:
continue
open.pop(ind)
axes1.append(i)
axes2.append(j)
break

return a, b, (axes1, axes2)

return st.builds(compatible_shapes_and_axes, shape1, shape2)


@given(data=generate_tensordot_cases())
def test_tensordot_hypothesis(data):
a, b, axes = data
np_result = np.tensordot(a, b, axes=axes)
ndx_result = ndx.tensordot(ndx.asarray(a), ndx.asarray(b), axes=axes).to_numpy()
assert_array_equal(np_result, ndx_result)


# Current repeat does not work on the upstream arrayapi tests in the case
# of an empty tensor as https://github.com/onnx/onnx/pull/6570 has not landed in onnx
@pytest.mark.parametrize("lazy_repeats", [False, True])
Expand Down
1 change: 1 addition & 0 deletions xfails.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# segmentation fault: https://github.com/onnx/onnx/pull/6570
array_api_tests/test_manipulation_functions.py::test_repeat
# segmentation fault: due to onnxruntime not handling einsum with dim 0
array_api_tests/test_linalg.py::test_tensordot
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved

array_api_tests/test_constants.py::test_newaxis
Expand Down
Loading