Skip to content

Commit

Permalink
change verbosity behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Jan 7, 2025
1 parent 0b45919 commit bd61266
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 26 deletions.
18 changes: 9 additions & 9 deletions ext/KrylovKitChainRulesCoreExt/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
b = (zerovector(v), convert(T, Δλ))
else
vdΔv = inner(v, Δv)
if alg_rrule.verbosity >= 0
if alg_primal.verbosity >= 1
gauge = abs(imag(vdΔv))
gauge > alg_primal.tol &&
@warn "`eigsolve` cotangent for eigenvector $i is sensitive to gauge choice: (|gauge| = $gauge)"
Expand All @@ -152,9 +152,9 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
return (y1, y2)
end
end
if info.converged >= i && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
if info.converged >= i && reverse_info.converged == 0 && alg_primal.verbosity >= 1
@warn "`eigsolve` cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did: normres = $(reverse_info.normres)"
elseif abs(w[2]) > alg_rrule.tol && alg_rrule.verbosity >= 0
elseif abs(w[2]) > alg_rrule.tol && alg_primal.verbosity >= 1
@warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $(w[2])"
end
ws[i] = w[1]
Expand Down Expand Up @@ -185,7 +185,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,

# components along subspace spanned by current eigenvectors
tol = alg_primal.tol
if alg_rrule.verbosity >= 0
if alg_primal.verbosity >= 1
mask = abs.(transpose(vals) .- vals) .< tol
gaugepart = VdΔV[mask] - Diagonal(real(diag(VdΔV)))[mask]
Δgauge = norm(gaugepart, Inf)
Expand Down Expand Up @@ -263,7 +263,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
return (w′, conj.(vals) .* x)
end
end
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
if info.converged >= n && reverse_info.converged < n && alg_primal.verbosity >= 1
@warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"
end
# cleanup and construct final result by renormalising the eigenvectors and explicitly
Expand All @@ -276,7 +276,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
w, x = Ws[ic]
factor = 1 / x[i]
x[i] = zero(x[i])
if alg_rrule.verbosity >= 0
if alg_primal.verbosity >= 1
error = max(norm(x, Inf), abs(rvals[ic] - conj(vals[i])))
error > 10 * tol &&
@warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $error"
Expand Down Expand Up @@ -308,7 +308,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
# components along subspace spanned by current eigenvectors
tol = alg_primal.tol
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)
if alg_rrule.verbosity >= 0
if alg_primal.verbosity >= 1
mask = abs.(transpose(vals) .- vals) .< tol
gaugepart = view(aVdΔV, mask)
gauge = norm(gaugepart, Inf)
Expand Down Expand Up @@ -366,7 +366,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
return (w′, vals .* x)
end
end
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
if info.converged >= n && reverse_info.converged < n && alg_primal.verbosity >= 1
@warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"
end

Expand All @@ -380,7 +380,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
factor = 1 / x[ic]
x[ic] = zero(x[ic])
error = max(norm(x, Inf), abs(rvals[i] - conj(vals[ic])))
if error > 5 * tol && alg_rrule.verbosity >= 0
if error > 10 * tol && alg_primal.verbosity >= 1
@warn "`eigsolve` cotangent linear problem ($ic) returns unexpected result: error = $error"
end
ws[ic] = VectorInterface.add!!(zs[ic], Q(w), -factor)
Expand Down
2 changes: 1 addition & 1 deletion ext/KrylovKitChainRulesCoreExt/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function make_linsolve_pullback(fᴴ, b, a₀, a₁, alg_rrule, construct∂f, x
a₁)))
∂b, reverse_info = linsolve(fᴴ, x̄, x̄₀, alg_rrule, conj(a₀),
conj(a₁))
if info.converged > 0 && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
if info.converged > 0 && reverse_info.converged == 0 && alg_primal.verbosity >= 1
@warn "`linsolve` cotangent problem did not converge, whereas the primal linear problem did: normres = $(reverse_info.normres)"
end
x∂b = inner(x, ∂b)
Expand Down
12 changes: 6 additions & 6 deletions ext/KrylovKitChainRulesCoreExt/svdsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
udΔu = inner(u, Δu)
vdΔv = inner(v, Δv)
if (udΔu isa Complex) || (vdΔv isa Complex)
if alg_rrule.verbosity >= 0
if alg_primal.verbosity >= 1
gauge = abs(imag(udΔu + vdΔv))
gauge > alg_primal.tol &&
@warn "`svdsolve` cotangents for singular vectors $i are sensitive to gauge choice: (|gauge| = $gauge)"
Expand All @@ -131,7 +131,7 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
return (x′, y′)
end
end
if info.converged >= i && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
if info.converged >= i && reverse_info.converged == 0 && alg_primal.verbosity >= 0
@warn "`svdsolve` cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did: normres = $(reverse_info.normres)"
end
x = VectorInterface.add!!(x, u, Δs / 2)
Expand Down Expand Up @@ -162,7 +162,7 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)

tol = alg_primal.tol
if alg_rrule.verbosity >= 0
if alg_primal.verbosity >= 1
mask = abs.(vals' .- vals) .< tol
gaugepart = view(aUdΔU, mask) + view(aVdΔV, mask)
gauge = norm(gaugepart, Inf)
Expand Down Expand Up @@ -227,7 +227,7 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
return (x′, y′, vals .* z)
end
end
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
if info.converged >= n && reverse_info.converged < n && alg_primal.verbosity >= 1
@warn "`svdsolve` cotangent problem did not converge, whereas the primal singular value problem did"
end

Expand All @@ -236,13 +236,13 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
for i in 1:n
x, y, z = Ws[i]
_, ic = findmax(abs, z)
if ic != i
if ic != i && alg_primal.verbosity >= 1
@warn "`svdsolve` cotangent linear problem ($ic) returns unexpected result"
end
factor = 1 / z[ic]
z[ic] = zero(z[ic])
error = max(norm(z, Inf), abs(rvals[i] - vals[ic]))
if error > 5 * tol && alg_rrule.verbosity >= 0
if error > 10 * tol && alg_primal.verbosity >= 1
@warn "`svdsolve` cotangent linear problem ($ic) returns unexpected result: error = $error vs tol = $tol"
end
xs[ic] = VectorInterface.add!!(xs[ic], x, -factor)
Expand Down
4 changes: 2 additions & 2 deletions test/ad/degenerateeigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ end

tol = 2 * N^2 * eps(real(T))
alg = Arnoldi(; tol=tol, krylovdim=2n)
alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1)
alg_rrule2 = GMRES(; tol=tol, krylovdim=2n, verbosity=-1)
alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n)
alg_rrule2 = GMRES(; tol=tol, krylovdim=2n)
mat_example1, mat_example_fun1, mat_example_fd, Avec, Bvec, Cvec, xvec, vals, vecs = build_mat_example(A,
B,
C,
Expand Down
14 changes: 10 additions & 4 deletions test/ad/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ end
condA = cond(A)
tol = n * condA * (T <: Real ? eps(T) : 4 * eps(real(T)))
alg = Arnoldi(; tol=tol, krylovdim=n)
alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n, verbosity=-1)
alg_rrule2 = GMRES(; tol=tol, krylovdim=n + 1, verbosity=-1)
alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n)
alg_rrule2 = GMRES(; tol=tol, krylovdim=n + 1)
config = Zygote.ZygoteRuleConfig()
@testset for which in whichlist
for alg_rrule in (alg_rrule1, alg_rrule2)
Expand Down Expand Up @@ -269,11 +269,13 @@ end

if T <: Complex
@testset "test warnings and info" begin
alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=-1)
alg = Arnoldi(; tol=tol, krylovdim=n, verbosity=0)
alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=0)
(vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany,
:LR, alg; alg_rrule=alg_rrule)
@test_logs pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], NoTangent()))

alg = Arnoldi(; tol=tol, krylovdim=n, verbosity=1)
alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=0)
(vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany,
:LR, alg; alg_rrule=alg_rrule)
Expand All @@ -282,6 +284,7 @@ end
pbs = @test_logs pb((ZeroTangent(), vecs[1:2], NoTangent()))
@test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T)))

alg = Arnoldi(; tol=tol, krylovdim=n, verbosity=1)
alg_rrule = Arnoldi(; tol=tol, krylovdim=n, verbosity=1)
(vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany,
:LR, alg; alg_rrule=alg_rrule)
Expand All @@ -290,11 +293,13 @@ end
pbs = @test_logs (:info,) pb((ZeroTangent(), vecs[1:2], NoTangent()))
@test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T)))

alg_rrule = GMRES(; tol=tol, krylovdim=n, verbosity=-1)
alg = Arnoldi(; tol=tol, krylovdim=n, verbosity=0)
alg_rrule = GMRES(; tol=tol, krylovdim=n, verbosity=0)
(vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany,
:LR, alg; alg_rrule=alg_rrule)
@test_logs pb((ZeroTangent(), im .* vecs[1:2] .+ vecs[2:-1:1], NoTangent()))

alg = Arnoldi(; tol=tol, krylovdim=n, verbosity=1)
alg_rrule = GMRES(; tol=tol, krylovdim=n, verbosity=0)
(vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany,
:LR, alg; alg_rrule=alg_rrule)
Expand All @@ -305,6 +310,7 @@ end
pbs = @test_logs pb((ZeroTangent(), vecs[1:2], NoTangent()))
@test norm(unthunk(pbs[1]), Inf) < condA * sqrt(eps(real(T)))

alg = Arnoldi(; tol=tol, krylovdim=n, verbosity=1)
alg_rrule = GMRES(; tol=tol, krylovdim=n, verbosity=1)
(vals, vecs, info), pb = ChainRulesCore.rrule(config, eigsolve, A, x, howmany,
:LR, alg; alg_rrule=alg_rrule)
Expand Down
14 changes: 10 additions & 4 deletions test/ad/svdsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ end
howmany = 3
tol = 3 * n * condA * (T <: Real ? eps(T) : 4 * eps(real(T)))
alg = GKL(; krylovdim=2n, tol=tol)
alg_rrule1 = Arnoldi(; tol=tol, krylovdim=4n, verbosity=-1)
alg_rrule2 = GMRES(; tol=tol, krylovdim=3n, verbosity=-1)
alg_rrule1 = Arnoldi(; tol=tol, krylovdim=4n)
alg_rrule2 = GMRES(; tol=tol, krylovdim=3n)
config = Zygote.ZygoteRuleConfig()
for alg_rrule in (alg_rrule1, alg_rrule2)
# unfortunately, rrule does not seem type stable for function arguments, because the
Expand Down Expand Up @@ -219,13 +219,15 @@ end
end
if T <: Complex
@testset "test warnings and info" begin
alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, verbosity=-1)
alg = GKL(; krylovdim=2n, tol=tol, verbosity=0)
alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, verbosity=0)
(vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x,
howmany, :LR, alg;
alg_rrule=alg_rrule)
@test_logs pb((ZeroTangent(), im .* lvecs[1:2] .+ lvecs[2:-1:1], ZeroTangent(),
NoTangent()))

alg = GKL(; krylovdim=2n, tol=tol, verbosity=1)
alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, verbosity=0)
(vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x,
howmany, :LR, alg;
Expand All @@ -249,6 +251,7 @@ end
(1 - im) .* rvecs[1:2] + rvecs[2:-1:1],
NoTangent()))

alg = GKL(; krylovdim=2n, tol=tol, verbosity=1)
alg_rrule = Arnoldi(; tol=tol, krylovdim=4n, verbosity=1)
(vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x,
howmany, :LR, alg;
Expand All @@ -272,13 +275,15 @@ end
(1 - im) .* rvecs[1:2] + rvecs[2:-1:1],
NoTangent()))

alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=-1)
alg = GKL(; krylovdim=2n, tol=tol, verbosity=0)
alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=0)
(vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x,
howmany, :LR, alg;
alg_rrule=alg_rrule)
@test_logs pb((ZeroTangent(), im .* lvecs[1:2] .+ lvecs[2:-1:1], ZeroTangent(),
NoTangent()))

alg = GKL(; krylovdim=2n, tol=tol, verbosity=1)
alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=0)
(vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x,
howmany, :LR, alg;
Expand All @@ -305,6 +310,7 @@ end
(1 - im) .* rvecs[1:2] + rvecs[2:-1:1],
NoTangent()))

alg = GKL(; krylovdim=2n, tol=tol, verbosity=1)
alg_rrule = GMRES(; tol=tol, krylovdim=3n, verbosity=1)
(vals, lvecs, rvecs, info), pb = ChainRulesCore.rrule(config, svdsolve, A, x,
howmany, :LR, alg;
Expand Down

0 comments on commit bd61266

Please sign in to comment.