Skip to content

Commit

Permalink
Add pinv and fix solve for numpy>=2.1 (#63)
Browse files Browse the repository at this point in the history
* add pinv

* fix failing tests with solve

* format

* update changelog

* prepare release

* update CI
  • Loading branch information
OriolAbril authored Sep 19, 2024
1 parent d21153a commit e1892a3
Show file tree
Hide file tree
Showing 15 changed files with 1,006 additions and 465 deletions.
6 changes: 6 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- name: Install build dependencies
run: python -m pip install build
- name: Build package
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]
fail-fast: false
steps:
- uses: actions/checkout@v3
Expand Down
6 changes: 3 additions & 3 deletions docs/source/changelog.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Change Log

## v0.x.x (Unreleased)
## v0.8.0 (2024 Sep 19)
### New features
* Add `numpy.linalg.pinv` wrapper {pull}`63`

### Maintenance and fixes

### Documentation
* Update to handle modified behaviour of `numpy.linalg.solve` {pull}`63`

## v0.7.0 (2024 Jan 17)
### New features
Expand Down
497 changes: 316 additions & 181 deletions docs/source/tutorials/linalg_tutorial.ipynb

Large diffs are not rendered by default.

389 changes: 316 additions & 73 deletions docs/source/tutorials/np_linalg_tutorial_port.ipynb

Large diffs are not rendered by default.

450 changes: 256 additions & 194 deletions docs/source/tutorials/stats_tutorial.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "flit_core.buildapi"
name = "xarray-einstats"
description = "Stats, linear algebra and einops for xarray"
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "ArviZ team", email = "[email protected]"}
Expand All @@ -27,8 +27,8 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"numpy>=1.22",
"scipy>=1.8",
"numpy>=1.23",
"scipy>=1.9",
"xarray>=2022.09.0",
]

Expand Down
2 changes: 1 addition & 1 deletion src/xarray_einstats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"EinopsAccessor",
]

__version__ = "0.8.0.dev0"
__version__ = "0.8.0"


def sort(da, dim, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions src/xarray_einstats/accessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Accessors for xarray_einstats features."""

import xarray as xr

from .linalg import (
Expand Down
1 change: 1 addition & 0 deletions src/xarray_einstats/einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
example usage.
"""

import warnings
from collections.abc import Hashable

Expand Down
90 changes: 84 additions & 6 deletions src/xarray_einstats/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
``matmul`` and ``get_default_dims``.
"""

import warnings

import numpy as np
Expand All @@ -34,6 +35,7 @@
"diagonal",
"solve",
"inv",
"pinv",
]


Expand Down Expand Up @@ -709,19 +711,75 @@ def solve(da, db, dims=None, **kwargs):
"""Wrap :func:`numpy.linalg.solve`.
Usage examples of all arguments is available at the :ref:`linalg_tutorial` page.
Parameters
----------
da : DataArray
db : DataArray
dims : sequence of hashable, optional
It can have either length 2 or 3. If length 2, both dimensions should have the
same length and be present in `da`, and only one of them should also be present in `db`.
If length 3, the first two elements behave the same; the third element is a dimension
of arbitrary length which can only present in `db`.
From NumPy's docstring, a has ``(..., M, M)`` shape and b has ``(M,) or (..., M, K)``.
Here, b can be ``(..., M)`` this case is not limited to 1d, so dims with length two
indicates the two dimensions of length M, with length 3 it is something like (M, M, K),
which can be done thanks to named dimensions.
**kwargs : mapping
Passed to :func:`xarray.apply_ufunc`
Examples
--------
Dimension naming conventions are designed to ease inverse operation with :func:`xarray.dot`.
The following example illustrates what this means and how to check that solve
worked correctly
.. jupyter-execute::
import xarray as xr
import numpy as np
from xarray_einstats.linalg import solve
from xarray_einstats.tutorial import generate_matrices_dataarray
matrices = generate_matrices_dataarray()
matrices
.. jupyter-execute::
b = matrices.std("dim2") # dims (batch, experiment, dim)
y2 = solve(matrices, b, dims=("dim", "dim2")) # dims (batch, experiment, dim2)
np.allclose(b, xr.dot(matrices, y2, dims="dim2"))
"""
if dims is None:
dims = _attempt_default_dims("solve", da.dims, db.dims)
if len(dims) == 3:
b_dim = dims[0] if dims[0] in db.dims else dims[1]
in_dims = [dims[:2], [b_dim, dims[-1]]]
out_dims = [[b_dim, dims[-1]]]
# solve(a, b) in numpy has signature a: (..., M, M) and b: (..., M, K)
# we look which dim is in b -> represents the M
k_dim = dims[-1] # the last element in dims represents the K
remove_k = False
if k_dim in da:
raise ValueError(
f"Found {k_dim} in `da`. If provided, the 3rd element of 'dims' "
"can only be in `db`."
)
else:
in_dims = [dims, dims[:1]]
out_dims = [dims[:1]]
return xr.apply_ufunc(
# a: (..., M, M) and b: (..., M) is not supported, so we add a dummy K
k_dim = "__k_aux_dim__"
remove_k = True
db = db.expand_dims(k_dim)
b_dim = dims[0] if dims[0] in db.dims else dims[1]
y_dim = dims[1] if dims[0] in db.dims else dims[0]
in_dims = [dims[:2], [b_dim, k_dim]]
out_dims = [[y_dim, k_dim]]
da_out = xr.apply_ufunc(
np.linalg.solve, da, db, input_core_dims=in_dims, output_core_dims=out_dims, **kwargs
)
if remove_k:
return da_out.squeeze(k_dim, drop=True)
return da_out


def inv(da, dims=None, **kwargs):
Expand All @@ -734,3 +792,23 @@ def inv(da, dims=None, **kwargs):
return xr.apply_ufunc(
np.linalg.inv, da, input_core_dims=[dims], output_core_dims=[dims], **kwargs
)


def pinv(da, dims=None, **kwargs):
"""Wrap :func:`numpy.linalg.pinv`.
Usage examples of all arguments is available at the :ref:`linalg_tutorial` page.
If both "rtol" and "rcond" are provided, "rtol" will be ignored.
"""
if dims is None:
dims = _attempt_default_dims("pinv", da.dims)
rcond = kwargs.pop("rtol", None)
rcond = kwargs.pop("rcond", rcond)
return xr.apply_ufunc(
np.linalg.pinv,
da,
rcond,
input_core_dims=[dims, []],
output_core_dims=[dims[::-1]],
**kwargs,
)
1 change: 1 addition & 0 deletions src/xarray_einstats/numba.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module with numba enhanced functions."""

import numba
import numpy as np
import xarray as xr
Expand Down
1 change: 1 addition & 0 deletions src/xarray_einstats/tutorial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tutorial module with data for docs and quick testing."""

import numpy as np
import xarray as xr

Expand Down
17 changes: 15 additions & 2 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
matrix_rank,
matrix_transpose,
norm,
pinv,
qr,
slogdet,
solve,
Expand Down Expand Up @@ -140,6 +141,12 @@ def test_inv(self, matrices):
assert out.shape == matrices.shape
assert out.dims == matrices.dims

def test_pinv(self, matrices):
out = pinv(matrices, dims=("experiment", "dim"))
out_dims_exp = ("batch", "dim2", "dim", "experiment")
assert out.dims == out_dims_exp
assert out.shape == tuple(out.sizes[dim] for dim in out_dims_exp)

def test_transpose(self, hermitian):
assert_equal(hermitian, matrix_transpose(hermitian, dims=("dim", "dim2")))

Expand Down Expand Up @@ -272,10 +279,16 @@ def test_slogdet_det(self, matrices):
det_da = det(matrices, dims=("dim", "dim2"))
assert_allclose(sign * np.exp(logdet), det_da)

def test_solve(self, matrices):
def test_solve_two_dims(self, matrices):
b = matrices.std("dim2")
y = solve(matrices, b, dims=("dim", "dim2"))
assert_allclose(b, xr.dot(matrices, y.rename(dim="dim2"), dims="dim2"), atol=1e-14)
assert_allclose(b, xr.dot(matrices, y, dim="dim2"), atol=1e-14)

def test_solve_three_dims(self, matrices):
b = matrices.std("dim2")
a = matrices.isel(batch=0)
y = solve(a, b, dims=("dim", "dim2", "batch"))
assert_allclose(b, xr.dot(a, y, dim="dim2").transpose(*b.dims), atol=1e-14)

def test_diagonal(self, matrices):
idx = xr.DataArray(np.arange(len(matrices["dim"])), dims="pointwise_sel")
Expand Down

0 comments on commit e1892a3

Please sign in to comment.