Skip to content

Commit 64c8c0f

Browse files
committed
ENH: add tests for eig and eigvals
1 parent 602b412 commit 64c8c0f

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,26 @@ def is_scalar(x):
199199
return isinstance(x, (int, float, complex, bool))
200200

201201

202+
def complex_for_float(dtyp):
203+
"""For a real or complex dtype, return a matching complex dtype."""
204+
if api_version <= '2021.12':
205+
raise TypeError("complex dtypes require api_version >= 2022.12.")
206+
207+
if dtyp not in all_float_dtypes:
208+
raise ValueError(f"expected a real dtype, got {dtyp}.")
209+
210+
if dtyp == xp.float32:
211+
return xp.complex64
212+
elif dtyp == xp.float64:
213+
return xp.complex128
214+
elif dtyp == xp.complex64:
215+
return xp.complex64
216+
elif dtype == xp.complex128:
217+
return xp.complex128
218+
else:
219+
raise ValueError(f"Unknown dtype {dtyp}.")
220+
221+
202222
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
203223
dtype_value_pairs = []
204224
for name, value in mapping.items():

array_api_tests/test_linalg.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,59 @@ def test_eigvalsh(x):
331331

332332
# TODO: Test that res actually corresponds to the eigenvalues of x
333333

334+
335+
@pytest.mark.unvectorized
336+
@pytest.mark.xp_extension('linalg')
337+
@pytest.mark.min_version("2025.12")
338+
@given(x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes))
339+
def test_eig(x):
340+
res = linalg.eig(x)
341+
342+
_test_namedtuple(res, ['eigenvalues', 'eigenvectors'], 'eig')
343+
344+
eigenvalues = res.eigenvalues
345+
eigenvectors = res.eigenvectors
346+
expected_dtype = dh.complex_for_float(x.dtype)
347+
348+
ph.assert_dtype("eig", in_dtype=x.dtype, out_dtype=eigenvalues.dtype,
349+
expected=expected_dtype, repr_name="eigenvalues.dtype")
350+
ph.assert_result_shape("eig", in_shapes=[x.shape],
351+
out_shape=eigenvalues.shape,
352+
expected=x.shape[:-1],
353+
repr_name="eigenvalues.shape")
354+
355+
ph.assert_dtype("eig", in_dtype=x.dtype, out_dtype=eigenvectors.dtype,
356+
expected=expected_dtype, repr_name="eigenvectors.dtype")
357+
ph.assert_result_shape("eig", in_shapes=[x.shape],
358+
out_shape=eigenvectors.shape, expected=x.shape,
359+
repr_name="eigenvectors.shape")
360+
361+
# TODO: Test that eigenvectors are orthonormal.
362+
363+
_test_stacks(lambda x: linalg.eig(x).eigenvectors, x,
364+
res=eigenvectors, dims=2)
365+
366+
# TODO: Test that res actually corresponds to the eigenvalues and
367+
# eigenvectors of x
368+
369+
370+
@pytest.mark.unvectorized
371+
@pytest.mark.xp_extension('linalg')
372+
@pytest.mark.min_version("2025.12")
373+
@given(x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes))
374+
def test_eigvals(x):
375+
res = linalg.eigvals(x)
376+
expected_dtype = dh.complex_for_float(x.dtype)
377+
378+
ph.assert_dtype("eigvals", in_dtype=x.dtype, out_dtype=res.dtype,
379+
expected=expected_dtype, repr_name="eigvals")
380+
ph.assert_result_shape("eigvals", in_shapes=[x.shape],
381+
out_shape=res.shape, expected=x.shape[:-1])
382+
# TODO: Test that res actually corresponds to the eigenvalues of x
383+
384+
_test_stacks(linalg.eigvals, x, res=res, dims=1)
385+
386+
334387
@pytest.mark.unvectorized
335388
@pytest.mark.xp_extension('linalg')
336389
@given(x=invertible_matrices())

0 commit comments

Comments
 (0)