Skip to content

Commit 31caaa4

Browse files
committed
ENH: add tests for eig and eigvals
1 parent bf8029a commit 31caaa4

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ def real_dtype_for(dtyp):
231231
return real_dtype
232232

233233

234-
235234
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
236235
dtype_value_pairs = []
237236
for name, value in mapping.items():

array_api_tests/test_linalg.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,59 @@ def test_eigvalsh(x):
331331

332332
# TODO: Test that res actually corresponds to the eigenvalues of x
333333

334+
335+
@pytest.mark.unvectorized
336+
@pytest.mark.xp_extension('linalg')
337+
@pytest.mark.min_version("2025.12")
338+
@given(x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes))
339+
def test_eig(x):
340+
res = linalg.eig(x)
341+
342+
_test_namedtuple(res, ['eigenvalues', 'eigenvectors'], 'eig')
343+
344+
eigenvalues = res.eigenvalues
345+
eigenvectors = res.eigenvectors
346+
expected_dtype = dh.complex_dtype_for(x.dtype)
347+
348+
ph.assert_dtype("eig", in_dtype=x.dtype, out_dtype=eigenvalues.dtype,
349+
expected=expected_dtype, repr_name="eigenvalues.dtype")
350+
ph.assert_result_shape("eig", in_shapes=[x.shape],
351+
out_shape=eigenvalues.shape,
352+
expected=x.shape[:-1],
353+
repr_name="eigenvalues.shape")
354+
355+
ph.assert_dtype("eig", in_dtype=x.dtype, out_dtype=eigenvectors.dtype,
356+
expected=expected_dtype, repr_name="eigenvectors.dtype")
357+
ph.assert_result_shape("eig", in_shapes=[x.shape],
358+
out_shape=eigenvectors.shape, expected=x.shape,
359+
repr_name="eigenvectors.shape")
360+
361+
# TODO: Test that eigenvectors are orthonormal.
362+
363+
_test_stacks(lambda x: linalg.eig(x).eigenvectors, x,
364+
res=eigenvectors, dims=2)
365+
366+
# TODO: Test that res actually corresponds to the eigenvalues and
367+
# eigenvectors of x
368+
369+
370+
@pytest.mark.unvectorized
371+
@pytest.mark.xp_extension('linalg')
372+
@pytest.mark.min_version("2025.12")
373+
@given(x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes))
374+
def test_eigvals(x):
375+
res = linalg.eigvals(x)
376+
expected_dtype = dh.complex_dtype_for(x.dtype)
377+
378+
ph.assert_dtype("eigvals", in_dtype=x.dtype, out_dtype=res.dtype,
379+
expected=expected_dtype, repr_name="eigvals")
380+
ph.assert_result_shape("eigvals", in_shapes=[x.shape],
381+
out_shape=res.shape, expected=x.shape[:-1])
382+
# TODO: Test that res actually corresponds to the eigenvalues of x
383+
384+
_test_stacks(linalg.eigvals, x, res=res, dims=1)
385+
386+
334387
@pytest.mark.unvectorized
335388
@pytest.mark.xp_extension('linalg')
336389
@given(x=invertible_matrices())

0 commit comments

Comments
 (0)