Skip to content

Commit

Permalink
Change rrule verbosity behaviour (#106)
Browse files Browse the repository at this point in the history
* use MinimalVec from VectorInterface

* change verbosity behaviour

* change verbosity level structure

* make default verbosity warn, disable ad tests until fixed

* update tests accordingly

* bump VectorInterface requirement, fix tests for Julia 1.6

* update rrule verbosity and tests

* (hopefully) final fixes in tests
  • Loading branch information
Jutho authored Jan 14, 2025
1 parent 3304d77 commit 3b2adff
Show file tree
Hide file tree
Showing 40 changed files with 1,160 additions and 845 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ ChainRulesCore = "1"
ChainRulesTestUtils = "1"
FiniteDifferences = "0.12"
LinearAlgebra = "1"
Logging = "1"
PackageExtensionCompat = "1"
Printf = "1"
Random = "1"
Test = "1"
TestExtras = "0.2,0.3"
VectorInterface = "0.4,0.5"
VectorInterface = "0.5"
Zygote = "0.6"
julia = "1.6"

Expand All @@ -36,9 +37,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Aqua", "TestExtras", "ChainRulesTestUtils", "ChainRulesCore", "FiniteDifferences", "Zygote"]
test = ["Test", "Aqua", "Logging", "TestExtras", "ChainRulesTestUtils", "ChainRulesCore", "FiniteDifferences", "Zygote"]
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using LinearAlgebra
using VectorInterface

using KrylovKit: apply_normal, apply_adjoint
using KrylovKit: WARN_LEVEL, STARTSTOP_LEVEL, EACHITERATION_LEVEL

include("utilities.jl")
include("linsolve.jl")
Expand Down
22 changes: 13 additions & 9 deletions ext/KrylovKitChainRulesCoreExt/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
b = (zerovector(v), convert(T, Δλ))
else
vdΔv = inner(v, Δv)
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
gauge = abs(imag(vdΔv))
gauge > alg_primal.tol &&
@warn "`eigsolve` cotangent for eigenvector $i is sensitive to gauge choice: (|gauge| = $gauge)"
Expand All @@ -152,9 +152,11 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
return (y1, y2)
end
end
if info.converged >= i && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
if info.converged >= i && reverse_info.converged == 0 &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did: normres = $(reverse_info.normres)"
elseif abs(w[2]) > alg_rrule.tol && alg_rrule.verbosity >= 0
elseif abs(w[2]) > (alg_rrule.tol * norm(w[1])) &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $(w[2])"
end
ws[i] = w[1]
Expand Down Expand Up @@ -185,7 +187,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,

# components along subspace spanned by current eigenvectors
tol = alg_primal.tol
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
mask = abs.(transpose(vals) .- vals) .< tol
gaugepart = VdΔV[mask] - Diagonal(real(diag(VdΔV)))[mask]
Δgauge = norm(gaugepart, Inf)
Expand Down Expand Up @@ -263,7 +265,8 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
return (w′, conj.(vals) .* x)
end
end
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
if info.converged >= n && reverse_info.converged < n &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"
end
# cleanup and construct final result by renormalising the eigenvectors and explicitly
Expand All @@ -276,7 +279,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
w, x = Ws[ic]
factor = 1 / x[i]
x[i] = zero(x[i])
if alg_rrule.verbosity >= 0
if alg_primal.verbosity >= WARN_LEVEL
error = max(norm(x, Inf), abs(rvals[ic] - conj(vals[i])))
error > 10 * tol &&
@warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $error"
Expand Down Expand Up @@ -308,7 +311,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
# components along subspace spanned by current eigenvectors
tol = alg_primal.tol
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
mask = abs.(transpose(vals) .- vals) .< tol
gaugepart = view(aVdΔV, mask)
gauge = norm(gaugepart, Inf)
Expand Down Expand Up @@ -366,7 +369,8 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
return (w′, vals .* x)
end
end
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
if info.converged >= n && reverse_info.converged < n &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"
end

Expand All @@ -380,7 +384,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
factor = 1 / x[ic]
x[ic] = zero(x[ic])
error = max(norm(x, Inf), abs(rvals[i] - conj(vals[ic])))
if error > 5 * tol && alg_rrule.verbosity >= 0
if error > 10 * tol && alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent linear problem ($ic) returns unexpected result: error = $error"
end
ws[ic] = VectorInterface.add!!(zs[ic], Q(w), -factor)
Expand Down
3 changes: 2 additions & 1 deletion ext/KrylovKitChainRulesCoreExt/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ function make_linsolve_pullback(fᴴ, b, a₀, a₁, alg_rrule, construct∂f, x
a₁)))
∂b, reverse_info = linsolve(fᴴ, x̄, x̄₀, alg_rrule, conj(a₀),
conj(a₁))
if info.converged > 0 && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
if info.converged > 0 && reverse_info.converged == 0 &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`linsolve` cotangent problem did not converge, whereas the primal linear problem did: normres = $(reverse_info.normres)"
end
x∂b = inner(x, ∂b)
Expand Down
14 changes: 8 additions & 6 deletions ext/KrylovKitChainRulesCoreExt/svdsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
udΔu = inner(u, Δu)
vdΔv = inner(v, Δv)
if (udΔu isa Complex) || (vdΔv isa Complex)
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
gauge = abs(imag(udΔu + vdΔv))
gauge > alg_primal.tol &&
@warn "`svdsolve` cotangents for singular vectors $i are sensitive to gauge choice: (|gauge| = $gauge)"
Expand All @@ -131,7 +131,8 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
return (x′, y′)
end
end
if info.converged >= i && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
if info.converged >= i && reverse_info.converged == 0 &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`svdsolve` cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did: normres = $(reverse_info.normres)"
end
x = VectorInterface.add!!(x, u, Δs / 2)
Expand Down Expand Up @@ -162,7 +163,7 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)

tol = alg_primal.tol
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
mask = abs.(vals' .- vals) .< tol
gaugepart = view(aUdΔU, mask) + view(aVdΔV, mask)
gauge = norm(gaugepart, Inf)
Expand Down Expand Up @@ -227,7 +228,8 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
return (x′, y′, vals .* z)
end
end
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
if info.converged >= n && reverse_info.converged < n &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`svdsolve` cotangent problem did not converge, whereas the primal singular value problem did"
end

Expand All @@ -236,13 +238,13 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
for i in 1:n
x, y, z = Ws[i]
_, ic = findmax(abs, z)
if ic != i
if ic != i && alg_primal.verbosity >= WARN_LEVEL
@warn "`svdsolve` cotangent linear problem ($ic) returns unexpected result"
end
factor = 1 / z[ic]
z[ic] = zero(z[ic])
error = max(norm(z, Inf), abs(rvals[i] - vals[ic]))
if error > 5 * tol && alg_rrule.verbosity >= 0
if error > 10 * tol && alg_primal.verbosity >= WARN_LEVEL
@warn "`svdsolve` cotangent linear problem ($ic) returns unexpected result: error = $error vs tol = $tol"
end
xs[ic] = VectorInterface.add!!(xs[ic], x, -factor)
Expand Down
42 changes: 30 additions & 12 deletions src/KrylovKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ function Base.iterate(r::SplitRange, i=1)
end
Base.length(r::SplitRange) = r.outerlength

# Algorithm types
include("algorithms.jl")

# Structures to store a list of basis vectors
"""
abstract type Basis{T} end
Expand All @@ -122,14 +119,6 @@ See [`OrthonormalBasis`](@ref) for a specific implementation.
"""
abstract type Basis{T} end

include("orthonormal.jl")

# Dense linear algebra structures and functions used in the algorithms below
include("dense/givens.jl")
include("dense/linalg.jl")
include("dense/packedhessenberg.jl")
include("dense/reflector.jl")

# Simple coordinate basis vector, i.e. a vector of all zeros and a single one on position `k`:
"""
SimpleBasisVector(m, k)
Expand Down Expand Up @@ -164,6 +153,23 @@ end
# apply operators
include("apply.jl")

# Verbosity levels
const WARN_LEVEL = 1
const STARTSTOP_LEVEL = 2
const EACHITERATION_LEVEL = 3

# Algorithm types
include("algorithms.jl")

# OrthonormalBasis, orthogonalization and orthonormalization methods
include("orthonormal.jl")

# Dense linear algebra structures and functions used in the algorithms below
include("dense/givens.jl")
include("dense/linalg.jl")
include("dense/packedhessenberg.jl")
include("dense/reflector.jl")

# Krylov and related factorizations and their iterators
include("factorizations/krylov.jl")
include("factorizations/lanczos.jl")
Expand Down Expand Up @@ -217,7 +223,19 @@ function Base.show(io::IO, info::ConvergenceInfo)
" iterations and ",
info.numops,
" applications of the linear map;")
return println(io, "norms of residuals are given by $((info.normres...,)).")
return print(io, "norms of residuals are given by ", normres2string(info.normres), ".")
end

# Convert residual norms into strings for info and warning printing
normres2string::Number) = @sprintf("%.2e", β)
function normres2string(β)
s = "("
for i in 1:length(β)
s *= normres2string(β[i])
i < length(β) && (s *= ", ")
end
s *= ")"
return s
end

# vectors with modified inner product
Expand Down
Loading

0 comments on commit 3b2adff

Please sign in to comment.