-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
83 lines (76 loc) · 2.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import scipy.linalg
import numpy as np
# from https://github.com/scipy/scipy/pull/3556/files
def funm_psd(A, func, check_finite=True):
    """
    Evaluate a matrix function of a positive semi-definite matrix.

    Returns the value of the matrix-valued function ``f`` at `A`, where
    ``f`` is the extension of the scalar-valued function `func` to
    Hermitian matrices. `A` is eigendecomposed as ``A = V diag(w) V^H``,
    and the result is ``V diag(func(w)) V^H``.

    Eigenvalues that come out negative are clipped to zero before `func`
    is applied. For a PSD input such values only arise from round-off.

    Parameters
    ----------
    A : (N, N) array_like
        Positive semi-definite matrix (real symmetric or complex Hermitian).
        As with `scipy.linalg.eigh`, only its lower triangle is used.
    func : callable
        Callable that evaluates a scalar function f. It is called once with
        the 1-D array of eigenvalues, so it must be vectorized (e.g. a NumPy
        ufunc such as ``np.sqrt``, or wrapped with ``np.vectorize``).
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the input does contain infinities or
        NaNs.

    Returns
    -------
    ret : (N, N) ndarray
        Value of the matrix function at `A`.

    Raises
    ------
    ValueError
        If `A` is not a 2-D square array, or (with `check_finite`) contains
        non-finite values.
    scipy.linalg.LinAlgError
        If the eigenvalue computation does not converge.

    See also
    --------
    scipy.linalg.funm : Evaluate a matrix function without the psd restriction.

    Examples
    --------
    >>> import numpy as np
    >>> a = np.array([[1, 2], [2, 4]])
    >>> r = funm_psd(a, np.sqrt)
    >>> np.allclose(r, a / np.sqrt(5))
    True
    >>> np.allclose(r.dot(r), a)
    True
    """
    A = np.asarray(A)
    if A.ndim != 2:
        raise ValueError("Non-matrix input to matrix function.")
    if A.shape[0] != A.shape[1]:
        raise ValueError("Matrix function requires a square matrix.")
    w, v = scipy.linalg.eigh(A, check_finite=check_finite)
    # Round-off can push eigenvalues of a PSD matrix slightly below zero;
    # clip them so func (e.g. np.sqrt) is evaluated on its valid domain.
    w = np.maximum(w, 0)
    # ``v * func(w)`` scales column j of v by func(w[j]), i.e. computes
    # V @ diag(func(w)) without materialising the diagonal matrix.
    return (v * func(w)).dot(v.conj().T)
def sqrtm_psd(A, check_finite=True):
    """
    Matrix square root of a positive semi-definite matrix.

    Computes the positive semi-definite square root ``S`` of `A`
    (``S @ S == A``). It does this by taking the square roots of the
    eigenvalues of `A` through `funm_psd`. That helper clips
    slightly-negative eigenvalues (round-off) to zero first.

    Parameters
    ----------
    A : (N, N) array_like
        Positive semi-definite matrix whose square root to evaluate.
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the input does contain infinities or
        NaNs.

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `A`.

    Raises
    ------
    ValueError
        Propagated from `funm_psd` for non-matrix or non-finite input.

    See also
    --------
    scipy.linalg.sqrtm : Matrix square root without the psd restriction.

    Examples
    --------
    >>> import numpy as np
    >>> a = np.array([[1, 2], [2, 4]])
    >>> r = sqrtm_psd(a)
    >>> np.allclose(r, a / np.sqrt(5))
    True
    >>> np.allclose(r.dot(r), a)
    True
    """
    # Pass check_finite by keyword so this call cannot be silently misrouted
    # if funm_psd's positional parameter order ever changes.
    return funm_psd(A, np.sqrt, check_finite=check_finite)