bring LSMR implementation of jutho/krylovkit.jl#46 up to date #109
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -231,7 +231,7 @@ will search for the optimal `x` in a Krylov subspace of maximal size `maxiter`, | |
`norm(A*x - b) < tol`. Default verbosity level `verbosity` is zero, meaning that no output | ||
will be printed. | ||
|
||
See also: [`linsolve`](@ref), [`MINRES`](@ref), [`GMRES`](@ref), [`BiCG`](@ref), | ||
See also: [`linsolve`](@ref), [`MINRES`](@ref), [`GMRES`](@ref), [`BiCG`](@ref), [`LSMR`](@ref), | ||
[`BiCGStab`](@ref) | ||
""" | ||
struct CG{S<:Real} <: LinearSolver | ||
|
@@ -262,7 +262,7 @@ to as the restart parameter, and `maxiter` is the number of outer iterations, i. | |
cycles. The total iteration count, i.e. the number of expansion steps, is roughly | ||
`krylovdim` times the number of iterations. | ||
|
||
See also: [`linsolve`](@ref), [`BiCG`](@ref), [`BiCGStab`](@ref), [`CG`](@ref), | ||
See also: [`linsolve`](@ref), [`BiCG`](@ref), [`BiCGStab`](@ref), [`CG`](@ref), [`LSMR`](@ref), | ||
[`MINRES`](@ref) | ||
""" | ||
struct GMRES{O<:Orthogonalizer,S<:Real} <: LinearSolver | ||
|
@@ -281,6 +281,54 @@ function GMRES(; | |
return GMRES(orth, maxiter, krylovdim, tol, verbosity) | ||
end | ||
|
||
""" | ||
LSMR(; orth = KrylovDefaults.orth, atol = KrylovDefaults.tol, btol = KrylovDefaults.tol, | ||
conlim = 1 / min(atol, btol), maxiter = KrylovDefaults.maxiter, | ||
krylovdim = KrylovDefaults.krylovdim, λ = zero(atol), verbosity = 0) | ||
|
||
Represents the LSMR algorithm, which minimizes ``\\|Ax - b\\|^2 + \\|λx\\|^2`` in the Euclidean norm. | ||
If multiple solutions exist, the minimum-norm solution is returned. | ||
The method is based on the Golub-Kahan bidiagonalization process. It is | ||
algebraically equivalent to applying MINRES to the normal equations | ||
``(A^*A + λ^2I)x = A^*b``, but has better numerical properties, | ||
especially if ``A`` is ill-conditioned. | ||
|
||
- `atol::Number = 1e-6`, `btol::Number = 1e-6`: stopping tolerances. If both are | ||
1.0e-9 (say), the final residual norm should be accurate to about 9 digits. | ||
(The final `x` will usually have fewer correct digits, | ||
depending on `cond(A)` and the size of `λ`.) | ||
- `conlim::Number = 1e8`: stopping tolerance. LSMR terminates if an estimate | ||
of `cond(A)` exceeds `conlim`. For compatible systems `Ax = b`, | ||
`conlim` could be as large as 1.0e+12 (say). For least-squares | ||
problems, `conlim` should be less than 1.0e+8. | ||
Maximum precision can be obtained by setting | ||
`atol` = `btol` = `conlim` = zero, but the number of iterations | ||
may then be excessive. | ||
|
||
See also: [`linsolve`](@ref), [`BiCG`](@ref), [`BiCGStab`](@ref), [`CG`](@ref), | ||
[`MINRES`](@ref), [`GMRES`](@ref) | ||
|
||
""" | ||
struct LSMR{O<:Orthogonalizer,S<:Real} <: KrylovAlgorithm | ||
orth::O | ||
atol::S | ||
btol::S | ||
conlim::S | ||
maxiter::Int | ||
verbosity::Int | ||
λ::S | ||
Having unicode as names for fields can turn out to be not entirely convenient, since this means there is no alternative to access that field on a machine/terminal that has no support for this. While I also like to use unicode, I would advise to only do that for internal variable names, or in cases where an alternative alias can be provided. |
||
krylovdim::Int | ||
end | ||
function LSMR(; orth=KrylovDefaults.orth, | ||
atol=KrylovDefaults.tol, | ||
btol=KrylovDefaults.tol, | ||
conlim=1 / min(atol, btol), | ||
maxiter=KrylovDefaults.maxiter, | ||
krylovdim=KrylovDefaults.krylovdim, | ||
λ=zero(atol), | ||
verbosity=0) | ||
return LSMR(orth, atol, btol, conlim, maxiter, verbosity, λ, krylovdim) | ||
end | ||
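The review comment on the `λ` field mentions providing an ASCII alias. One way this could look is sketched below: the field is stored under an ASCII name and `Base.getproperty` forwards the unicode name to it. The type `DampedSolver` and the field name `lambda` are hypothetical and are not part of this PR or of KrylovKit.

```julia
# Hypothetical solver type, not part of KrylovKit: the field has an ASCII name.
struct DampedSolver{S<:Real}
    lambda::S
end

# Forward the unicode property `λ` to the ASCII field `lambda`.
function Base.getproperty(alg::DampedSolver, name::Symbol)
    name === :λ && return getfield(alg, :lambda)
    return getfield(alg, name)
end

# Advertise both names, e.g. for tab completion.
Base.propertynames(::DampedSolver) = (:lambda, :λ)

alg = DampedSolver(0.5)
@assert alg.λ == 0.5       # unicode access
@assert alg.lambda == 0.5  # ASCII access to the same value
```

With this layout, code inside the package can keep writing `alg.λ`, while users on terminals without unicode input can fall back to `alg.lambda`.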
|
||
# TODO | ||
""" | ||
MINRES(; maxiter = KrylovDefaults.maxiter, tol = KrylovDefaults.tol) | ||
|
@@ -295,7 +343,7 @@ end | |
orthogonalizer `orth`. Default verbosity level `verbosity` is zero, meaning that no | ||
output will be printed. | ||
|
||
See also: [`linsolve`](@ref), [`CG`](@ref), [`GMRES`](@ref), [`BiCG`](@ref), | ||
See also: [`linsolve`](@ref), [`CG`](@ref), [`GMRES`](@ref), [`BiCG`](@ref), [`LSMR`](@ref), | ||
[`BiCGStab`](@ref) | ||
""" | ||
struct MINRES{S<:Real} <: LinearSolver | ||
|
@@ -322,7 +370,7 @@ end | |
b) < tol`. Default verbosity level `verbosity` is zero, meaning that no output will be | ||
printed. | ||
|
||
See also: [`linsolve`](@ref), [`GMRES`](@ref), [`CG`](@ref), [`BiCGStab`](@ref), | ||
See also: [`linsolve`](@ref), [`GMRES`](@ref), [`CG`](@ref), [`BiCGStab`](@ref), [`LSMR`](@ref), | ||
[`MINRES`](@ref) | ||
""" | ||
struct BiCG{S<:Real} <: LinearSolver | ||
|
@@ -346,7 +394,7 @@ end | |
of maximal size `maxiter`, or stop when `norm(A*x - b) < tol`. Default verbosity level | ||
`verbosity` is zero, meaning that no output will be printed. | ||
|
||
See also: [`linsolve`](@ref), [`GMRES`](@ref), [`CG`](@ref), [`BiCG`](@ref), | ||
See also: [`linsolve`](@ref), [`GMRES`](@ref), [`CG`](@ref), [`BiCG`](@ref), [`LSMR`](@ref), | ||
[`MINRES`](@ref) | ||
""" | ||
struct BiCGStab{S<:Real} <: LinearSolver | ||
|
Probably it would be good to have an additional mention to the docstring that LSMR requires a different interface for the function handles, since it needs both the regular as well as the adjoint action. |
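To illustrate what "both the regular as well as the adjoint action" means for a function handle, here is a minimal sketch using only `LinearAlgebra`. Packing the two closures into a tuple `(f, fadj)` is an assumption made for illustration; the exact form that `linsolve` expects for `LSMR` in this PR is determined by `apply_normal` and `apply_adjoint` and is not confirmed here.

```julia
using LinearAlgebra

# A rectangular (3×2) matrix, so the forward and adjoint maps act on different spaces.
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]

f(x) = A * x       # regular action:  R^2 -> R^3
fadj(y) = A' * y   # adjoint action:  R^3 -> R^2

# Hypothetical packaging of the pair; an assumption, not the confirmed API.
operator = (f, fadj)

x = [1.0, -1.0]
y = [0.5, 0.0, 2.0]

# A correct adjoint satisfies <f(x), y> == <x, fadj(y)>.
lhs = dot(operator[1](x), y)
rhs = dot(x, operator[2](y))
@assert isapprox(lhs, rhs)
```

The inner-product check at the end is a cheap way to catch a mismatched adjoint before handing the pair to a solver.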
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
# reference implementation https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl/blob/master/src/lsmr.jl | ||
function linsolve(operator, b, alg::LSMR) | ||
return linsolve(operator, b, zerovector(apply_adjoint(operator, b)), alg) | ||
end | ||
function linsolve(operator, b, x₀, alg::LSMR) | ||
u = add!!(apply_normal(operator, x₀), b, 1, -1) | ||
β = norm(u) | ||
|
||
# initialize GKL factorization | ||
iter = GKLIterator(operator, u, alg.orth) | ||
fact = initialize(iter; verbosity=alg.verbosity - 2) | ||
numops = 2 | ||
sizehint!(fact, alg.krylovdim) | ||
|
||
T = eltype(fact) | ||
Tr = real(T) | ||
ctol = alg.conlim > 0 ? convert(Tr, inv(alg.conlim)) : zero(Tr) | ||
istop = 0 | ||
|
||
# TODO: make this an explicit copy that works with the testing datatypes | ||
x = x₀ | ||
|
||
for topit in 1:(alg.maxiter) # the outermost restart loop | ||
# Initialize variables for 1st iteration. | ||
α = fact.αs[end] | ||
ζbar = α * β | ||
αbar = α | ||
ρ = one(Tr) | ||
ρbar = one(Tr) | ||
cbar = one(Tr) | ||
sbar = zero(Tr) | ||
|
||
# Initialize variables for estimation of ||r||. | ||
βdd = β | ||
βd = zero(Tr) | ||
ρdold = one(Tr) | ||
τtildeold = zero(Tr) | ||
θtilde = zero(Tr) | ||
ζ = zero(Tr) | ||
d = zero(Tr) | ||
|
||
# Initialize variables for estimation of ||A|| and cond(A). | ||
normA, condA, normx = -one(Tr), -one(Tr), -one(Tr) | ||
normA2 = abs2(α) | ||
maxrbar = zero(Tr) | ||
minrbar = 1e100 | ||
|
||
# Items for use in stopping rules. | ||
normb = β | ||
normr = β | ||
normAr = α * β | ||
|
||
hbar = scale(x, zero(T)) | ||
h = scale(fact.V[end], one(T)) | ||
|
||
while length(fact) < alg.krylovdim | ||
β = normres(fact) | ||
fact = expand!(iter, fact) | ||
numops += 2 | ||
|
||
v = fact.V[end] | ||
α = fact.αs[end] | ||
|
||
# Construct rotation Qhat_{k,2k+1}. | ||
αhat = hypot(αbar, alg.λ) | ||
chat = αbar / αhat | ||
shat = alg.λ / αhat | ||
|
||
# Use a plane rotation (Q_i) to turn B_i to R_i. | ||
ρold = ρ | ||
ρ = hypot(αhat, β) | ||
c = αhat / ρ | ||
s = β / ρ | ||
θnew = s * α | ||
αbar = c * α | ||
|
||
# Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar. | ||
ρbarold = ρbar | ||
ζold = ζ | ||
θbar = sbar * ρ | ||
ρtemp = cbar * ρ | ||
ρbar = hypot(cbar * ρ, θnew) | ||
cbar = cbar * ρ / ρbar | ||
sbar = θnew / ρbar | ||
ζ = cbar * ζbar | ||
ζbar = -sbar * ζbar | ||
|
||
# Update h, hbar, x. | ||
hbar = add!!(hbar, h, 1, -θbar * ρ / (ρold * ρbarold)) | ||
h = add!!(h, v, 1, -θnew / ρ) | ||
x = add!!(x, hbar, ζ / (ρ * ρbar), 1) | ||
|
||
############################################################################## | ||
## | ||
## Estimate of ||r|| | ||
## | ||
############################################################################## | ||
|
||
# Apply rotation Qhat_{k,2k+1}. | ||
βacute = chat * βdd | ||
βcheck = -shat * βdd | ||
|
||
# Apply rotation Q_{k,k+1}. | ||
βhat = c * βacute | ||
βdd = -s * βacute | ||
|
||
# Apply rotation Qtilde_{k-1}. | ||
θtildeold = θtilde | ||
ρtildeold = hypot(ρdold, θbar) | ||
ctildeold = ρdold / ρtildeold | ||
stildeold = θbar / ρtildeold | ||
θtilde = stildeold * ρbar | ||
ρdold = ctildeold * ρbar | ||
βd = -stildeold * βd + ctildeold * βhat | ||
|
||
τtildeold = (ζold - θtildeold * τtildeold) / ρtildeold | ||
τd = (ζ - θtilde * τtildeold) / ρdold | ||
d += abs2(βcheck) | ||
normr = sqrt(d + abs2(βd - τd) + abs2(βdd)) | ||
|
||
# Estimate ||A||. | ||
normA2 += abs2(β) | ||
normA = sqrt(normA2) | ||
normA2 += abs2(α) | ||
|
||
# Estimate cond(A). | ||
maxrbar = max(maxrbar, ρbarold) | ||
if length(fact) > 1 | ||
minrbar = min(minrbar, ρbarold) | ||
end | ||
condA = max(maxrbar, ρtemp) / min(minrbar, ρtemp) | ||
|
||
############################################################################## | ||
## | ||
## Test for convergence | ||
## | ||
############################################################################## | ||
|
||
# Compute norms for convergence testing. | ||
normAr = abs(ζbar) | ||
normx = norm(x) | ||
|
||
# Now use these norms to estimate certain other quantities, | ||
# some of which will be small near a solution. | ||
test1 = normr / normb | ||
test2 = normAr / (normA * normr) | ||
test3 = inv(condA) | ||
|
||
Comment on lines +143 to +148
I'm a bit confused by the placement of these variables. Is there a reason to define these before the logger, and to then use
Are these different convergence criteria currently used? If not, I would probably streamline them to be compatible with what we use in the other methods, even if this "reduces" the functionality or options to specify the convergence. I will take a stab at this myself. |
||
t1 = test1 / (one(Tr) + normA * normx / normb) | ||
rtol = alg.btol + alg.atol * normA * normx / normb | ||
# The following tests guard against extremely small values of | ||
# atol, btol or ctol. (The user may have set any or all of | ||
# the parameters atol, btol, conlim to 0.) | ||
# The effect is equivalent to the normal tests using | ||
# atol = eps, btol = eps, conlim = 1/eps. | ||
|
||
if alg.verbosity > 2 | ||
msg = "LSMR linsolve in iter $topit; step $(length(fact)-1): " | ||
msg *= "normres = " | ||
msg *= @sprintf("%.12e", normr) | ||
@info msg | ||
end | ||
|
||
if 1 + test3 <= 1 | ||
Can this be re-expressed as |
||
istop = 6 | ||
break | ||
end | ||
if 1 + test2 <= 1 | ||
istop = 5 | ||
break | ||
end | ||
if 1 + t1 <= 1 | ||
istop = 4 | ||
break | ||
end | ||
# Allow for tolerances set by the user. | ||
if test3 <= ctol | ||
istop = 3 | ||
break | ||
end | ||
if test2 <= alg.atol | ||
istop = 2 | ||
break | ||
end | ||
if test1 <= rtol | ||
istop = 1 | ||
break | ||
end | ||
end | ||
|
||
u = add!!(apply_normal(operator, x), b, 1, -1) | ||
|
||
istop != 0 && break | ||
|
||
#restart | ||
β = norm(u) | ||
iter = GKLIterator(operator, u, alg.orth) | ||
fact = initialize!(iter, fact) | ||
end | ||
|
||
isconv = istop ∉ (0, 3, 6) | ||
if alg.verbosity > 0 && !isconv | ||
@warn """LSMR linsolve finished without converging after $(alg.maxiter) iterations: | ||
* norm of residual = $(norm(u)) | ||
* number of operations = $numops""" | ||
elseif alg.verbosity > 0 | ||
@info """LSMR linsolve converged due to istop $(istop): | ||
* norm of residual = $(norm(u)) | ||
* number of operations = $numops""" | ||
end | ||
return (x, ConvergenceInfo(Int(isconv), u, norm(u), alg.maxiter, numops)) | ||
end |
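The chain of stopping tests in the inner loop above can be read more easily when pulled out into a standalone function. The sketch below mirrors the order and the `istop` codes of the diff; the function name `lsmr_istop` and its argument list are hypothetical and do not appear in the PR.

```julia
# Hypothetical helper, not part of the PR: returns the `istop` code that the
# inner loop would assign, or 0 if no stopping rule fires.
function lsmr_istop(normr, normAr, normA, normx, normb, condA; atol, btol, ctol)
    test1 = normr / normb               # relative residual
    test2 = normAr / (normA * normr)    # relative normal-equation residual
    test3 = inv(condA)                  # inverse condition-number estimate
    t1 = test1 / (1 + normA * normx / normb)
    rtol = btol + atol * normA * normx / normb

    # Guards against tolerances below machine precision.
    1 + test3 <= 1 && return 6
    1 + test2 <= 1 && return 5
    1 + t1 <= 1 && return 4
    # Tolerances set by the user.
    test3 <= ctol && return 3
    test2 <= atol && return 2
    test1 <= rtol && return 1
    return 0
end

# A small residual relative to `b` triggers the `test1 <= rtol` rule.
code = lsmr_istop(1e-10, 1e-3, 1.0, 1.0, 1.0, 10.0; atol=1e-6, btol=1e-6, ctol=1e-8)
@assert code == 1
```

Codes 3 and 6 both come from the condition-number estimate, which is why the diff treats them as non-converged in `isconv = istop ∉ (0, 3, 6)`.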
I guess you should probably also test a non-square matrix here, since that is the primary use-case of the method? |
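As a rough idea of what a non-square test could check, the sketch below builds a small overdetermined problem and verifies the two properties stated in the `LSMR` docstring: the undamped minimizer satisfies the normal equations, and the damped minimizer solves ``(A^*A + λ^2I)x = A^*b``. It uses only `LinearAlgebra` with the backslash operator as a reference solver; in an actual test the reference solutions would be compared against the output of `linsolve` with `LSMR`, whose call signature is not assumed here.

```julia
using LinearAlgebra

# Tall 3×2 system: more equations than unknowns, so Ax = b has no exact solution.
A = [1.0 0.0; 1.0 1.0; 1.0 2.0]
b = [6.0, 0.0, 0.0]

# Reference least-squares solution (backslash uses a QR factorization here).
x_ref = A \ b
@assert isapprox(x_ref, [5.0, -3.0]; atol=1e-10)

# At the minimizer the residual is orthogonal to the range of A.
@assert norm(A' * (A * x_ref - b)) < 1e-10

# Damped problem: minimize ||Ax - b||^2 + ||λx||^2.
λ = 0.1
x_normal = (A' * A + λ^2 * I) \ (A' * b)                        # normal equations
x_augmented = [A; λ * Matrix{Float64}(I, 2, 2)] \ [b; zeros(2)] # stacked least squares
@assert isapprox(x_normal, x_augmented; atol=1e-10)
```

A test along these lines exercises the case where the domain and codomain have different dimensions, which a square test matrix cannot reveal.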
I think this docstring could use some formatting improvements and a bit of a clean up 😉