Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change rrule verbosity behaviour #106

Merged
merged 8 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ ChainRulesCore = "1"
ChainRulesTestUtils = "1"
FiniteDifferences = "0.12"
LinearAlgebra = "1"
Logging = "1"
PackageExtensionCompat = "1"
Printf = "1"
Random = "1"
Test = "1"
TestExtras = "0.2,0.3"
VectorInterface = "0.4,0.5"
VectorInterface = "0.5"
Zygote = "0.6"
julia = "1.6"

Expand All @@ -36,9 +37,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Aqua", "TestExtras", "ChainRulesTestUtils", "ChainRulesCore", "FiniteDifferences", "Zygote"]
test = ["Test", "Aqua", "Logging", "TestExtras", "ChainRulesTestUtils", "ChainRulesCore", "FiniteDifferences", "Zygote"]
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using LinearAlgebra
using VectorInterface

using KrylovKit: apply_normal, apply_adjoint
using KrylovKit: WARN_LEVEL, STARTSTOP_LEVEL, EACHITERATION_LEVEL

include("utilities.jl")
include("linsolve.jl")
Expand Down
22 changes: 13 additions & 9 deletions ext/KrylovKitChainRulesCoreExt/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
b = (zerovector(v), convert(T, Δλ))
else
vdΔv = inner(v, Δv)
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
gauge = abs(imag(vdΔv))
gauge > alg_primal.tol &&
@warn "`eigsolve` cotangent for eigenvector $i is sensitive to gauge choice: (|gauge| = $gauge)"
Expand All @@ -152,9 +152,11 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
return (y1, y2)
end
end
if info.converged >= i && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
if info.converged >= i && reverse_info.converged == 0 &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did: normres = $(reverse_info.normres)"
elseif abs(w[2]) > alg_rrule.tol && alg_rrule.verbosity >= 0
elseif abs(w[2]) > (alg_rrule.tol * norm(w[1])) &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $(w[2])"
end
ws[i] = w[1]
Expand Down Expand Up @@ -185,7 +187,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,

# components along subspace spanned by current eigenvectors
tol = alg_primal.tol
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
mask = abs.(transpose(vals) .- vals) .< tol
gaugepart = VdΔV[mask] - Diagonal(real(diag(VdΔV)))[mask]
Δgauge = norm(gaugepart, Inf)
Expand Down Expand Up @@ -263,7 +265,8 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
return (w′, conj.(vals) .* x)
end
end
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
if info.converged >= n && reverse_info.converged < n &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"
end
# cleanup and construct final result by renormalising the eigenvectors and explicitly
Expand All @@ -276,7 +279,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
w, x = Ws[ic]
factor = 1 / x[i]
x[i] = zero(x[i])
if alg_rrule.verbosity >= 0
if alg_primal.verbosity >= WARN_LEVEL
error = max(norm(x, Inf), abs(rvals[ic] - conj(vals[i])))
error > 10 * tol &&
@warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $error"
Expand Down Expand Up @@ -308,7 +311,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
# components along subspace spanned by current eigenvectors
tol = alg_primal.tol
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
mask = abs.(transpose(vals) .- vals) .< tol
gaugepart = view(aVdΔV, mask)
gauge = norm(gaugepart, Inf)
Expand Down Expand Up @@ -366,7 +369,8 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
return (w′, vals .* x)
end
end
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
if info.converged >= n && reverse_info.converged < n &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"
end

Expand All @@ -380,7 +384,7 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
factor = 1 / x[ic]
x[ic] = zero(x[ic])
error = max(norm(x, Inf), abs(rvals[i] - conj(vals[ic])))
if error > 5 * tol && alg_rrule.verbosity >= 0
if error > 10 * tol && alg_primal.verbosity >= WARN_LEVEL
@warn "`eigsolve` cotangent linear problem ($ic) returns unexpected result: error = $error"
end
ws[ic] = VectorInterface.add!!(zs[ic], Q(w), -factor)
Expand Down
3 changes: 2 additions & 1 deletion ext/KrylovKitChainRulesCoreExt/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ function make_linsolve_pullback(fᴴ, b, a₀, a₁, alg_rrule, construct∂f, x
a₁)))
∂b, reverse_info = linsolve(fᴴ, x̄, x̄₀, alg_rrule, conj(a₀),
conj(a₁))
if info.converged > 0 && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
if info.converged > 0 && reverse_info.converged == 0 &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`linsolve` cotangent problem did not converge, whereas the primal linear problem did: normres = $(reverse_info.normres)"
end
x∂b = inner(x, ∂b)
Expand Down
14 changes: 8 additions & 6 deletions ext/KrylovKitChainRulesCoreExt/svdsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
udΔu = inner(u, Δu)
vdΔv = inner(v, Δv)
if (udΔu isa Complex) || (vdΔv isa Complex)
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
gauge = abs(imag(udΔu + vdΔv))
gauge > alg_primal.tol &&
@warn "`svdsolve` cotangents for singular vectors $i are sensitive to gauge choice: (|gauge| = $gauge)"
Expand All @@ -131,7 +131,8 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
return (x′, y′)
end
end
if info.converged >= i && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
if info.converged >= i && reverse_info.converged == 0 &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`svdsolve` cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did: normres = $(reverse_info.normres)"
end
x = VectorInterface.add!!(x, u, Δs / 2)
Expand Down Expand Up @@ -162,7 +163,7 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)

tol = alg_primal.tol
if alg_rrule.verbosity >= 0
if alg_rrule.verbosity >= WARN_LEVEL
mask = abs.(vals' .- vals) .< tol
gaugepart = view(aUdΔU, mask) + view(aVdΔV, mask)
gauge = norm(gaugepart, Inf)
Expand Down Expand Up @@ -227,7 +228,8 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
return (x′, y′, vals .* z)
end
end
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
if info.converged >= n && reverse_info.converged < n &&
alg_primal.verbosity >= WARN_LEVEL
@warn "`svdsolve` cotangent problem did not converge, whereas the primal singular value problem did"
end

Expand All @@ -236,13 +238,13 @@ function compute_svdsolve_pullback_data(Δvals, Δlvecs, Δrvecs, vals, lvecs, r
for i in 1:n
x, y, z = Ws[i]
_, ic = findmax(abs, z)
if ic != i
if ic != i && alg_primal.verbosity >= WARN_LEVEL
@warn "`svdsolve` cotangent linear problem ($ic) returns unexpected result"
end
factor = 1 / z[ic]
z[ic] = zero(z[ic])
error = max(norm(z, Inf), abs(rvals[i] - vals[ic]))
if error > 5 * tol && alg_rrule.verbosity >= 0
if error > 10 * tol && alg_primal.verbosity >= WARN_LEVEL
@warn "`svdsolve` cotangent linear problem ($ic) returns unexpected result: error = $error vs tol = $tol"
end
xs[ic] = VectorInterface.add!!(xs[ic], x, -factor)
Expand Down
42 changes: 30 additions & 12 deletions src/KrylovKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@
end
Base.length(r::SplitRange) = r.outerlength

# Algorithm types
include("algorithms.jl")

# Structures to store a list of basis vectors
"""
abstract type Basis{T} end
Expand All @@ -122,14 +119,6 @@
"""
abstract type Basis{T} end

include("orthonormal.jl")

# Dense linear algebra structures and functions used in the algorithms below
include("dense/givens.jl")
include("dense/linalg.jl")
include("dense/packedhessenberg.jl")
include("dense/reflector.jl")

# Simple coordinate basis vector, i.e. a vector of all zeros and a single one on position `k`:
"""
SimpleBasisVector(m, k)
Expand Down Expand Up @@ -164,6 +153,23 @@
# apply operators
include("apply.jl")

# Verbosity levels
const WARN_LEVEL = 1
const STARTSTOP_LEVEL = 2
const EACHITERATION_LEVEL = 3

# Algorithm types
include("algorithms.jl")

# OrthonormalBasis, orthogonalization and orthonormalization methods
include("orthonormal.jl")

# Dense linear algebra structures and functions used in the algorithms below
include("dense/givens.jl")
include("dense/linalg.jl")
include("dense/packedhessenberg.jl")
include("dense/reflector.jl")

# Krylov and related factorizations and their iterators
include("factorizations/krylov.jl")
include("factorizations/lanczos.jl")
Expand Down Expand Up @@ -217,7 +223,19 @@
" iterations and ",
info.numops,
" applications of the linear map;")
return println(io, "norms of residuals are given by $((info.normres...,)).")
return print(io, "norms of residuals are given by ", normres2string(info.normres), ".")

Check warning on line 226 in src/KrylovKit.jl

View check run for this annotation

Codecov / codecov/patch

src/KrylovKit.jl#L226

Added line #L226 was not covered by tests
end

# Convert residual norms into strings for info and warning printing
normres2string(β::Number) = @sprintf("%.2e", β)
function normres2string(β)
s = "("
for i in 1:length(β)
s *= normres2string(β[i])
i < length(β) && (s *= ", ")
end
s *= ")"
return s
end

# vectors with modified inner product
Expand Down
Loading
Loading