Skip to content

Commit 6de6662

Browse files
committed
Implement fast 2x2 svd. Add test
1 parent e96e63f commit 6de6662

File tree

3 files changed

+176
-0
lines changed

3 files changed

+176
-0
lines changed

Fastor/expressions/linalg_ops/linalg_ops.h

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "Fastor/expressions/linalg_ops/unary_norm_op.h"
1717
#include "Fastor/expressions/linalg_ops/unary_qr_op.h"
1818
#include "Fastor/expressions/linalg_ops/unary_det_op.h"
19+
#include "Fastor/expressions/linalg_ops/unary_svd_op.h"
1920
#include "Fastor/expressions/linalg_ops/binary_cross_op.h"
2021

2122
#include "Fastor/expressions/linalg_ops/linalg_traits.h"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#ifndef UNARY_SVD_OP_H
2+
#define UNARY_SVD_OP_H
3+
4+
#include "Fastor/meta/meta.h"
5+
#include "Fastor/backend/inner.h"
6+
#include "Fastor/backend/lufact.h"
7+
#include "Fastor/simd_vector/SIMDVector.h"
8+
#include "Fastor/tensor/AbstractTensor.h"
9+
#include "Fastor/tensor/Aliasing.h"
10+
#include "Fastor/tensor/Tensor.h"
11+
#include "Fastor/tensor/TensorTraits.h"
12+
#include "Fastor/expressions/expression_traits.h"
13+
#include "Fastor/expressions/linalg_ops/linalg_computation_types.h"
14+
#include "Fastor/expressions/linalg_ops/unary_det_op.h"
15+
#include "Fastor/expressions/linalg_ops/unary_trans_op.h"
16+
17+
18+
namespace Fastor {
19+
20+
// SVD
21+
template<typename T, size_t M, enable_if_t_<M==2, bool> = false >
22+
FASTOR_INLINE void svd(const Tensor<T,M,M> &A, Tensor<T,M,M> &U, Tensor<T,M,M> &S, Tensor<T,M,M> &V) {
23+
24+
constexpr T Epsilon_v = std::numeric_limits<T>::epsilon();
25+
26+
const T f00 = A(0, 0);
27+
const T f01 = A(0, 1);
28+
const T f10 = A(1, 0);
29+
const T f11 = A(1, 1);
30+
31+
// If matrix is diagonal, SVD is trivial
32+
if (std::abs(f01 - f10) < Epsilon_v && std::abs(f01) < Epsilon_v)
33+
{
34+
// Compute U
35+
U(0,0) = f00 < 0 ? -1. : 1.;
36+
U(0,1) = 0.;
37+
U(1,0) = 0.;
38+
U(1,1) = f11 < 0. ? -1. : 1.;
39+
40+
// Compute S
41+
S(0,0) = std::abs(f00);
42+
S(0,1) = 0;
43+
S(1,0) = 0;
44+
S(1,1) = std::abs(f11);
45+
46+
// Compute V
47+
V.eye2();
48+
}
49+
// Otherwise, we need to compute A^T*A
50+
else
51+
{
52+
T j = f00 * f00 + f01 * f01;
53+
T k = f10 * f10 + f11 * f11;
54+
T v_c = f00 * f10 + f01 * f11;
55+
// Check to see if A^T*A is diagonal
56+
if (std::abs(v_c) < Epsilon_v)
57+
{
58+
// Compute S
59+
T s1 = std::sqrt(j);
60+
T s2 = std::abs(j - k) < Epsilon_v ? s1 : std::sqrt(k);
61+
S(0,0) = s1;
62+
S(0,1) = 0;
63+
S(1,0) = 0;
64+
S(1,1) = s2;
65+
66+
// Compute U
67+
U.eye2();
68+
69+
// Compute V
70+
V(0,0) = f00 / s1;
71+
V(0,1) = f10 / s2;
72+
V(1,0) = f01 / s1;
73+
V(1,1) = f11 / s2;
74+
}
75+
// Otherwise, solve quadratic equation for eigenvalues
76+
else
77+
{
78+
T jmk = j - k;
79+
T jpk = j + k;
80+
T root = std::sqrt(jmk * jmk + 4. * v_c * v_c);
81+
T eig1 = (jpk + root) * 0.5;
82+
T eig2 = (jpk - root) * 0.5;
83+
84+
// Compute S
85+
T s1 = std::sqrt(eig1);
86+
T s2 = std::abs(root) < Epsilon_v ? s1 : ( eig2 > 0 ? std::sqrt(eig2) : Epsilon_v);
87+
S(0,0) = s1;
88+
S(0,1) = 0;
89+
S(1,0) = 0;
90+
S(1,1) = s2;
91+
92+
// Compute U - use eigenvectors of A^T*A as U
93+
T v_s = eig1 - j;
94+
T len = std::max(std::sqrt(v_s * v_s + v_c * v_c), Epsilon_v);
95+
v_c /= len;
96+
v_s /= len;
97+
98+
U(0,0) = v_c;
99+
U(0,1) = -v_s;
100+
U(1,0) = v_s;
101+
U(1,1) = v_c;
102+
103+
// Compute V - as A * U / s
104+
const T cc = (f00 * v_c + f10 * v_s) / s1;
105+
const T cs = (f01 * v_c + f11 * v_s) / s1;
106+
if (std::abs(s2) > Epsilon_v)
107+
{
108+
V(0,0) = cc;
109+
V(0,1) = (f10* v_c - f00 * v_s) / s2;
110+
V(1,0) = cs;
111+
V(1,1) = (f11 * v_c - f01 * v_s) / s2;
112+
}
113+
else
114+
{
115+
V(0,0) = cc;
116+
V(0,1) = cs;
117+
V(1,0) = cs;
118+
V(1,1) = -cc;
119+
}
120+
}
121+
}
122+
}
123+
124+
125+
126+
// Signed SVD
127+
template<typename T, size_t M>
128+
FASTOR_INLINE void ssvd(const Tensor<T,M,M> &A, Tensor<T,M,M> &U, Tensor<T,M,M> &S, Tensor<T,M,M> &V) {
129+
130+
// Same as above but avoiding the L matrix
131+
svd(A, U, S, V);
132+
133+
// See where to pull the reflection out of
134+
const T detU = determinant(U);
135+
const T detV = determinant(V);
136+
137+
if (detU >= 0 && detV >= 0)
138+
{
139+
// No reflection svd == svd_rv, return
140+
return;
141+
}
142+
143+
Tensor<T, M, M> L = matmul(U, transpose(V));
144+
const T lastColumn = determinant(L);
145+
146+
if (detU < 0 && detV > 0)
147+
{
148+
U(all, M - 1) *= lastColumn;
149+
}
150+
else if (detU > 0 && detV < 0)
151+
{
152+
V(all, M - 1) *= lastColumn;
153+
}
154+
155+
// Push the reflection to the diagonal
156+
S(M - 1, M - 1) *= lastColumn;
157+
}
158+
//-----------------------------------------------------------------------------------------------------------//
159+
160+
} // end of namespace Fastor
161+
162+
163+
#endif // UNARY_SVD_OP_H

tests/test_linalg/test_linalg.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,18 @@ void test_linalg() {
136136
FASTOR_EXIT_ASSERT(std::abs(determinant(M) - 216) < Tol);
137137
}
138138

139+
// svd
140+
{
141+
Tensor<T,2,2> A; A = {{23.,3.},{3.5,56.}};
142+
Tensor<T,2,2> U,S,V;
143+
svd(A,U,S,V);
144+
Tensor<T,2,2> rec = U % S % trans(V);
145+
FASTOR_EXIT_ASSERT(std::abs(rec(0,0) - A(0,0)) < BigTol);
146+
FASTOR_EXIT_ASSERT(std::abs(rec(0,1) - A(0,1)) < BigTol);
147+
FASTOR_EXIT_ASSERT(std::abs(rec(1,0) - A(1,0)) < BigTol);
148+
FASTOR_EXIT_ASSERT(std::abs(rec(1,1) - A(1,1)) < BigTol);
149+
}
150+
139151
print(FGRN(BOLD("All tests passed successfully")));
140152

141153
}

0 commit comments

Comments
 (0)