Skip to content

Commit

Permalink
Schur and generalized Schur decomposition (#6)
Browse files Browse the repository at this point in the history
* schur and generalized schur

* fix tests
  • Loading branch information
mohamed82008 committed Oct 16, 2022
1 parent 296453c commit a580f91
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 83 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ julia = "1"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["FiniteDifferences", "Test", "Zygote"]
test = ["FiniteDifferences", "Random", "Test", "Zygote"]
52 changes: 51 additions & 1 deletion src/DifferentiableFactorizations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module DifferentiableFactorizations

export diff_qr, diff_cholesky, diff_lu, diff_eigen, diff_svd
export diff_qr, diff_cholesky, diff_lu, diff_eigen, diff_svd, diff_schur

using LinearAlgebra, ImplicitDifferentiation, ComponentArrays, ChainRulesCore

Expand Down Expand Up @@ -120,6 +120,56 @@ function diff_eigen(A, B)
return (; s , V)
end

# Residual vector of the equations that tie the Schur factors to `A`.
#
# Arguments
#   A   : the matrix being factorized.
#   Z_T : any object exposing the properties `Z` and `T`.
#
# Returns the concatenation of two vectorized matrices:
#   1. `Z' * A * Z - T`
#   2. `Z' * Z - I + LowerTriangular(T) - Diagonal(T)`
#
# NOTE(review): the second residual contains every strictly-lower entry of
# `T`, so it can only vanish when `T` is upper triangular; presumably inputs
# are expected to have a triangular (not quasi-triangular) Schur factor --
# confirm against callers.
function schur_conditions(A, Z_T)
    Z = Z_T.Z
    T = Z_T.T
    similarity = Z' * A * Z - T
    # Evaluated left to right on purpose: regrouping the terms would change
    # the floating-point result on the diagonal.
    structure = Z' * Z - I + LowerTriangular(T) - Diagonal(T)
    return vcat(vec(similarity), vec(structure))
end
# Forward pass for the Schur factorization: run `schur(A)` and pack its `Z`
# and `T` properties into a single `ComponentVector` with fields `Z` and `T`.
function schur_forward(A)
    F = schur(A)
    return ComponentVector(; Z = F.Z, T = F.T)
end
# `ImplicitFunction` pairing the forward computation `schur_forward` with the
# residual `schur_conditions`; the single-argument `diff_schur` calls it.
const _diff_schur = ImplicitFunction(schur_forward, schur_conditions)

# Lower-bidiagonal matrix with main diagonal `v1` and subdiagonal `v2`.
# A named wrapper around `Bidiagonal(v1, v2, :L)`; a custom `rrule` is
# defined for this function in the same file.
bidiag(v1, v2) = Bidiagonal(v1, v2, :L)
# Reverse-mode rule for `bidiag`.
#
# Returns the primal `bidiag(v1, v2)` together with a pullback. Given an
# output cotangent `Δ`, the pullback yields `NoTangent()` for the function
# itself, `diag(Δ)` for `v1` (main diagonal) and `diag(Δ, -1)` for `v2`
# (first subdiagonal).
function ChainRulesCore.rrule(::typeof(bidiag), v1, v2)
    function bidiag_pullback(Δ)
        return NoTangent(), diag(Δ), diag(Δ, -1)
    end
    return bidiag(v1, v2), bidiag_pullback
end

# Residual vector of the equations that tie the generalized Schur factors to
# the matrix pair stored in `AB`.
#
# Arguments
#   AB             : object exposing the properties `A` and `B`.
#   left_right_S_T : object exposing the properties `left`, `right`, `S`, `T`.
#
# Returns the concatenation of four vectorized matrices:
#   1. `left * S * right' - A`
#   2. `left * T * right' - B`
#   3. upper triangle of `left' * left` minus `I`, plus the lower triangle of
#      `S` with its diagonal removed and its subdiagonal replaced by
#      `-S[i+1, i] * T[i, i+1]`
#   4. upper triangle of `right' * right` minus `I`, plus the strictly lower
#      triangle of `T`
#
# NOTE(review): the product on the subdiagonal of residual 3 presumably lets
# a subdiagonal entry of `S` be nonzero as long as the matching superdiagonal
# entry of `T` is zero -- confirm the intended normalisation.
function gen_schur_conditions(AB, left_right_S_T)
    A, B = AB.A, AB.B
    left = left_right_S_T.left
    right = left_right_S_T.right
    S = left_right_S_T.S
    T = left_right_S_T.T

    sub_S = diag(S, -1)
    S_mask = bidiag(diag(S), sub_S .+ (sub_S .* diag(T, 1)))

    res_A = left * S * right' - A
    res_B = left * T * right' - B
    # Both structural residuals are evaluated strictly left to right so the
    # floating-point result matches the single-expression form.
    res_left = UpperTriangular(left' * left) - I + LowerTriangular(S) - S_mask
    res_right = UpperTriangular(right' * right) - I + LowerTriangular(T) - Diagonal(T)

    return vcat(vec(res_A), vec(res_B), vec(res_left), vec(res_right))
end
# Forward pass for the generalized Schur factorization: run
# `schur(AB.A, AB.B)` and pack the `left`, `right`, `S` and `T` properties of
# the result into one `ComponentVector` with the same field names.
function gen_schur_forward(AB)
    F = schur(AB.A, AB.B)
    return ComponentVector(; left = F.left, right = F.right, S = F.S, T = F.T)
end
# `ImplicitFunction` pairing the forward computation `gen_schur_forward` with
# the residual `gen_schur_conditions`; the two-argument `diff_schur` calls it.
const _diff_gen_schur = ImplicitFunction(gen_schur_forward, gen_schur_conditions)

"""
    diff_schur(A, B)

Return the generalized Schur factors of the pair `(A, B)` as the named tuple
`(; left, right, S, T)`.

The inputs are packed with `comp_vec(A, B)` and passed to `_diff_gen_schur`;
the four fields are read back from its result.
"""
function diff_schur(A, B)
    factors = _diff_gen_schur(comp_vec(A, B))
    return (; left = factors.left, right = factors.right, S = factors.S, T = factors.T)
end
"""
    diff_schur(A)

Return the Schur factors of `A` as the named tuple `(; Z, T)`.

`A` is passed to `_diff_schur` and the two fields are read back from its
result.
"""
function diff_schur(A)
    factors = _diff_schur(A)
    return (; Z = factors.Z, T = factors.T)
end

# SVD

function svd_conditions(A, USV)
Expand Down
222 changes: 141 additions & 81 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,98 +1,158 @@
using DifferentiableFactorizations, Test, Zygote, FiniteDifferences, LinearAlgebra, ComponentArrays
using DifferentiableFactorizations, Test, Zygote, FiniteDifferences, LinearAlgebra, ComponentArrays, Random
Random.seed!(1)

@testset "Cholesky" begin
A = rand(3, 3)

f1(A) = diff_cholesky(A' * A + 2I).U
zjac1 = Zygote.jacobian(f1, A)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1]
@test norm(zjac1 - fjac1) < 1e-9
# Number of independent random inputs each testset loops over.
const nreps = 3
# Upper bound on the norm of the difference between the Zygote Jacobian and
# the FiniteDifferences Jacobian in every `@test` below.
const tol = 1e-8

f2(A) = diff_cholesky(A' * A + 2I).L
zjac2 = Zygote.jacobian(f2, A)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, A)[1]
@test norm(zjac2 - fjac2) < 1e-9
# Checks the Jacobians of both Cholesky factors returned by `diff_cholesky`:
# the Zygote Jacobian must agree with a central finite-difference Jacobian to
# within `tol`, for `nreps` random inputs.
@testset "Cholesky" begin
for _ in 1:nreps
A = rand(3, 3)

# `A' * A + 2I` is the matrix actually factorized; differentiate the `U` factor
# with respect to the raw `A`.
f1(A) = diff_cholesky(A' * A + 2I).U
zjac1 = Zygote.jacobian(f1, A)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1]
@test norm(zjac1 - fjac1) < tol

# Same comparison for the `L` factor.
f2(A) = diff_cholesky(A' * A + 2I).L
zjac2 = Zygote.jacobian(f2, A)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, A)[1]
@test norm(zjac2 - fjac2) < tol
end
end

@testset "LU" begin
A = rand(3, 3)

f1(A) = vec(diff_lu(A).U)
zjac1 = Zygote.jacobian(f1, A)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1]
@test norm(zjac1 - fjac1) < 1e-9

f2(A) = vec(diff_lu(A).L)
zjac2 = Zygote.jacobian(f2, A)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, A)[1]
@test norm(zjac2 - fjac2) < 1e-9
for _ in 1:nreps
A = rand(3, 3)

f1(A) = vec(diff_lu(A).U)
zjac1 = Zygote.jacobian(f1, A)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1]
@test norm(zjac1 - fjac1) < tol

f2(A) = vec(diff_lu(A).L)
zjac2 = Zygote.jacobian(f2, A)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, A)[1]
@test norm(zjac2 - fjac2) < tol
end
end

@testset "QR" begin
A = rand(3, 2)

f1(A) = vec(diff_qr(A).Q)
zjac1 = Zygote.jacobian(f1, A)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1]
@test norm(zjac1 - fjac1) < 1e-9

f2(A) = vec(diff_qr(A).R)
zjac2 = Zygote.jacobian(f2, A)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, A)[1]
@test norm(zjac2 - fjac2) < 1e-9
for _ in 1:nreps
A = rand(3, 2)

f1(A) = vec(diff_qr(A).Q)
zjac1 = Zygote.jacobian(f1, A)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1]
@test norm(zjac1 - fjac1) < tol

f2(A) = vec(diff_qr(A).R)
zjac2 = Zygote.jacobian(f2, A)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, A)[1]
@test norm(zjac2 - fjac2) < tol
end
end

@testset "Eigen" begin
A = rand(3, 3)
B = rand(3, 3)
AB = ComponentVector(; A, B)

f1(AB) = begin
A = AB.A' * AB.A
B = AB.B' * AB.B + 2I
diff_eigen(A, B).s
end
zjac1 = Zygote.jacobian(f1, AB)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, AB)[1]
@test norm(zjac1 - fjac1) < 1e-9

f2(AB) = begin
A = AB.A' * AB.A
B = AB.B' * AB.B + 2I
vec(diff_eigen(A, B).V)
for _ in 1:nreps
A = rand(3, 3)
B = rand(3, 3)
AB = ComponentVector(; A, B)

f1(AB) = begin
A = AB.A' * AB.A
B = AB.B' * AB.B + 2I
diff_eigen(A, B).s
end
zjac1 = Zygote.jacobian(f1, AB)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, AB)[1]
@test norm(zjac1 - fjac1) < tol

f2(AB) = begin
A = AB.A' * AB.A
B = AB.B' * AB.B + 2I
vec(diff_eigen(A, B).V)
end
zjac2 = Zygote.jacobian(f2, AB)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, AB)[1]
@test norm(zjac2 - fjac2) < tol

f3(A) = diff_eigen(A' * A).s
zjac3 = Zygote.jacobian(f3, A)[1]
fjac3 = FiniteDifferences.jacobian(central_fdm(5, 1), f3, A)[1]
@test norm(zjac3 - fjac3) < tol

# Seems eigen does not guarantee differentiability of the output V without matrix B - the FiniteDifferences jacobian has large numbers

# f4(A) = vec(diff_eigen(A' * A + 5I).V)
# zjac4 = Zygote.jacobian(f4, A)[1]
# fjac4 = FiniteDifferences.jacobian(central_fdm(5, 1), f4, A)[1]
# @test norm(zjac4 - fjac4) < tol
end
zjac2 = Zygote.jacobian(f2, AB)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, AB)[1]
@test norm(zjac2 - fjac2) < 1e-9

f3(A) = diff_eigen(A' * A).s
zjac3 = Zygote.jacobian(f3, A)[1]
fjac3 = FiniteDifferences.jacobian(central_fdm(5, 1), f3, A)[1]
@test norm(zjac3 - fjac3) < 1e-9
end

# Seems eigen does not guarantee differentiability of the output V without matrix B - the FiniteDifferences jacobian has large numbers
# Checks the Jacobians of the three outputs of `diff_svd` (`S`, `U`, `V`):
# the Zygote Jacobian must agree with a central finite-difference Jacobian to
# within `tol`, for `nreps` random 3x3 inputs.
@testset "SVD" begin
for _ in 1:nreps
A = rand(3, 3)

# `S` output.
f1(A) = diff_svd(A).S
zjac1 = Zygote.jacobian(f1, A)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1]
@test norm(zjac1 - fjac1) < tol

# `U` output, flattened with `vec`.
f2(A) = vec(diff_svd(A).U)
zjac2 = Zygote.jacobian(f2, A)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, A)[1]
@test norm(zjac2 - fjac2) < tol

# `V` output, flattened with `vec`.
f3(A) = vec(diff_svd(A).V)
zjac3 = Zygote.jacobian(f3, A)[1]
fjac3 = FiniteDifferences.jacobian(central_fdm(5, 1), f3, A)[1]
@test norm(zjac3 - fjac3) < tol
end
end

# f4(A) = vec(diff_eigen(A' * A + 5I).V)
# zjac4 = Zygote.jacobian(f4, A)[1]
# fjac4 = FiniteDifferences.jacobian(central_fdm(5, 1), f4, A)[1]
# @test norm(zjac4 - fjac4) < 1e-9
# Checks the Jacobians of both outputs of the single-argument `diff_schur`
# (`Z` and `T`): the Zygote Jacobian must agree with a central
# finite-difference Jacobian to within `tol`, for `nreps` random inputs.
@testset "Schur" begin
for _ in 1:nreps
A = randn(3, 3)
# Symmetrize the base point (`A' + A`) and shift it by the identity.
# NOTE(review): presumably done so the Schur factor `T` is triangular at the
# evaluation point -- confirm.
A = A' + A + I
f1(A) = vec(diff_schur(A).Z)
zjac1 = Zygote.jacobian(f1, A)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1]
@test norm(zjac1 - fjac1) < tol

# Same comparison for the `T` factor.
f2(A) = vec(diff_schur(A).T)
zjac2 = Zygote.jacobian(f2, A)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, A)[1]
@test norm(zjac2 - fjac2) < tol
end
end

@testset "SVD" begin
A = rand(3, 3)

f1(A) = diff_svd(A).S
zjac1 = Zygote.jacobian(f1, A)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, A)[1]
@test norm(zjac1 - fjac1) < 1e-9

f2(A) = vec(diff_svd(A).U)
zjac2 = Zygote.jacobian(f2, A)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, A)[1]
@test norm(zjac2 - fjac2) < 1e-9

f3(A) = vec(diff_svd(A).V)
zjac3 = Zygote.jacobian(f3, A)[1]
fjac3 = FiniteDifferences.jacobian(central_fdm(5, 1), f3, A)[1]
@test norm(zjac3 - fjac3) < 1e-9
# Checks the Jacobians of the four outputs of the two-argument `diff_schur`
# (`left`, `right`, `S`, `T`): the Zygote Jacobian must agree with a central
# finite-difference Jacobian to within `tol`, for `nreps` random matrix pairs.
@testset "Generalized Schur" begin
for _ in 1:nreps
# Both base points are symmetrized and shifted by the identity.
# NOTE(review): `A` is drawn with `randn` but `B` with `rand` -- confirm the
# asymmetry is intended.
A = randn(3, 3)
A = A' + A + I
B = rand(3, 3)
B = B' + B + I
# Pack both matrices into one vector so a single Jacobian covers `A` and `B`.
AB = ComponentVector(; A, B)

# `left` factor.
f1(AB) = vec(diff_schur(AB.A, AB.B).left)
zjac1 = Zygote.jacobian(f1, AB)[1]
fjac1 = FiniteDifferences.jacobian(central_fdm(5, 1), f1, AB)[1]
@test norm(zjac1 - fjac1) < tol

# `right` factor.
f2(AB) = vec(diff_schur(AB.A, AB.B).right)
zjac2 = Zygote.jacobian(f2, AB)[1]
fjac2 = FiniteDifferences.jacobian(central_fdm(5, 1), f2, AB)[1]
@test norm(zjac2 - fjac2) < tol

# `S` factor.
f3(AB) = vec(diff_schur(AB.A, AB.B).S)
zjac3 = Zygote.jacobian(f3, AB)[1]
fjac3 = FiniteDifferences.jacobian(central_fdm(5, 1), f3, AB)[1]
@test norm(zjac3 - fjac3) < tol

# `T` factor.
f4(AB) = vec(diff_schur(AB.A, AB.B).T)
zjac4 = Zygote.jacobian(f4, AB)[1]
fjac4 = FiniteDifferences.jacobian(central_fdm(5, 1), f4, AB)[1]
@test norm(zjac4 - fjac4) < tol
end
end

0 comments on commit a580f91

Please sign in to comment.