Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bring LSMR implementation of jutho/krylovkit.jl#46 up to date #109

Closed
wants to merge 11 commits into from
1 change: 1 addition & 0 deletions docs/src/man/algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ KrylovKit.MINRES
GMRES
KrylovKit.BiCG
BiCGStab
LSMR
```
## Specific algorithms for generalized eigenvalue problems
```@docs
Expand Down
3 changes: 2 additions & 1 deletion src/KrylovKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export initialize, initialize!, expand!, shrink!
export ClassicalGramSchmidt, ClassicalGramSchmidt2, ClassicalGramSchmidtIR
export ModifiedGramSchmidt, ModifiedGramSchmidt2, ModifiedGramSchmidtIR
export LanczosIterator, ArnoldiIterator, GKLIterator
export CG, GMRES, BiCGStab, Lanczos, Arnoldi, GKL, GolubYe
export CG, GMRES, BiCGStab, Lanczos, Arnoldi, GKL, GolubYe, LSMR
export KrylovDefaults, EigSorter
export RecursiveVec, InnerProductVec

Expand Down Expand Up @@ -235,6 +235,7 @@ include("linsolve/linsolve.jl")
include("linsolve/cg.jl")
include("linsolve/gmres.jl")
include("linsolve/bicgstab.jl")
include("linsolve/lsmr.jl")

# eigsolve and svdsolve
include("eigsolve/eigsolve.jl")
Expand Down
58 changes: 53 additions & 5 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ will search for the optimal `x` in a Krylov subspace of maximal size `maxiter`,
`norm(A*x - b) < tol`. Default verbosity level `verbosity` is zero, meaning that no output
will be printed.

See also: [`linsolve`](@ref), [`MINRES`](@ref), [`GMRES`](@ref), [`BiCG`](@ref),
See also: [`linsolve`](@ref), [`MINRES`](@ref), [`GMRES`](@ref), [`BiCG`](@ref), [`LSMR`](@ref),
[`BiCGStab`](@ref)
"""
struct CG{S<:Real} <: LinearSolver
Expand Down Expand Up @@ -262,7 +262,7 @@ to as the restart parameter, and `maxiter` is the number of outer iterations, i.
cycles. The total iteration count, i.e. the number of expansion steps, is roughly
`krylovdim` times the number of iterations.

See also: [`linsolve`](@ref), [`BiCG`](@ref), [`BiCGStab`](@ref), [`CG`](@ref),
See also: [`linsolve`](@ref), [`BiCG`](@ref), [`BiCGStab`](@ref), [`CG`](@ref), [`LSMR`](@ref),
[`MINRES`](@ref)
"""
struct GMRES{O<:Orthogonalizer,S<:Real} <: LinearSolver
Expand All @@ -281,6 +281,54 @@ function GMRES(;
return GMRES(orth, maxiter, krylovdim, tol, verbosity)
end

"""
LSMR(; orth = KrylovDefaults.orth,atol = KrylovDefaults.tol,btol = KrylovDefaults.tol,conlim = 1/KrylovDefaults.tol,
maxiter = KrylovDefaults.maxiter,krylovdim = KrylovDefaults.krylovdim,λ = 0.0,verbosity = 0)
Comment on lines +284 to +286
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this docstring could use some formatting improvements and a bit of a clean up 😉


Represents the LSMR algorithm, which minimizes ``\\|Ax - b\\|^2 + \\|λx\\|^2`` in the Euclidean norm.
If multiple solutions exists the minimum norm solution is returned.
The method is based on the Golub-Kahan bidiagonalization process. It is
algebraically equivalent to applying MINRES to the normal equations
``(A^*A + λ^2I)x = A^*b``, but has better numerical properties,
especially if ``A`` is ill-conditioned.

- `atol::Number = 1e-6`, `btol::Number = 1e-6`: stopping tolerances. If both are
1.0e-9 (say), the final residual norm should be accurate to about 9 digits.
(The final `x` will usually have fewer correct digits,
depending on `cond(A)` and the size of damp).
- `conlim::Number = 1e8`: stopping tolerance. `lsmr` terminates if an estimate
of `cond(A)` exceeds conlim. For compatible systems Ax = b,
conlim could be as large as 1.0e+12 (say). For least-squares
problems, conlim should be less than 1.0e+8.
Maximum precision can be obtained by setting
- `atol` = `btol` = `conlim` = zero, but the number of iterations
may then be excessive.

See also: [`linsolve`](@ref), [`BiCG`](@ref), [`BiCGStab`](@ref), [`CG`](@ref),
[`MINRES`](@ref), [`GMRES`](@ref)

"""
struct LSMR{O<:Orthogonalizer,S<:Real} <: KrylovAlgorithm
orth::O
atol::S
btol::S
conlim::S
maxiter::Int
verbosity::Int
λ::S
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having unicode as names for fields can turn out to be not entirely convenient, since this means there is no alternative to access that field on a machine/terminal that has no support for this. While I also like to use unicode, I would advise to only do that for internal variable names, or in cases where an alternative alias can be provided.

krylovdim::Int
end
function LSMR(; orth=KrylovDefaults.orth,
atol=KrylovDefaults.tol,
btol=KrylovDefaults.tol,
conlim=1 / min(atol, btol),
maxiter=KrylovDefaults.maxiter,
krylovdim=KrylovDefaults.krylovdim,
λ=zero(atol),
verbosity=0)
return LSMR(orth, atol, btol, conlim, maxiter, verbosity, λ, krylovdim)
end

# TODO
"""
MINRES(; maxiter = KrylovDefaults.maxiter, tol = KrylovDefaults.tol)
Expand All @@ -295,7 +343,7 @@ end
orthogonalizer `orth`. Default verbosity level `verbosity` is zero, meaning that no
output will be printed.

See also: [`linsolve`](@ref), [`CG`](@ref), [`GMRES`](@ref), [`BiCG`](@ref),
See also: [`linsolve`](@ref), [`CG`](@ref), [`GMRES`](@ref), [`BiCG`](@ref), [`LSMR`](@ref),
[`BiCGStab`](@ref)
"""
struct MINRES{S<:Real} <: LinearSolver
Expand All @@ -322,7 +370,7 @@ end
b) < tol`. Default verbosity level `verbosity` is zero, meaning that no output will be
printed.

See also: [`linsolve`](@ref), [`GMRES`](@ref), [`CG`](@ref), [`BiCGStab`](@ref),
See also: [`linsolve`](@ref), [`GMRES`](@ref), [`CG`](@ref), [`BiCGStab`](@ref), [`LSMR`](@ref),
[`MINRES`](@ref)
"""
struct BiCG{S<:Real} <: LinearSolver
Expand All @@ -346,7 +394,7 @@ end
of maximal size `maxiter`, or stop when `norm(A*x - b) < tol`. Default verbosity level
`verbosity` is zero, meaning that no output will be printed.

See also: [`linsolve`](@ref), [`GMRES`](@ref), [`CG`](@ref), [`BiCG`](@ref),
See also: [`linsolve`](@ref), [`GMRES`](@ref), [`CG`](@ref), [`BiCG`](@ref), [`LSMR`](@ref),
[`MINRES`](@ref)
"""
struct BiCGStab{S<:Real} <: LinearSolver
Expand Down
2 changes: 1 addition & 1 deletion src/linsolve/linsolve.jl
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably it would be good to have an additional mention to the docstring that LSMR requires a different interface for the function handles, since it needs both the regular as well as the adjoint action.

Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ efficiently. Check the documentation for more information on the possible values

The final (expert) method, without default values and keyword arguments, is the one that is
finally called, and can also be used directly. Here, one specifies the algorithm explicitly.
Currently, only [`CG`](@ref), [`GMRES`](@ref) and [`BiCGStab`](@ref) are implemented, where
Currently, only [`CG`](@ref), [`GMRES`](@ref), [`BiCGStab`](@ref) and [`LSMR`](@ref) are implemented, where
`CG` is chosen if `isposdef == true` and `GMRES` is chosen otherwise. Note that in standard
`GMRES` terminology, our parameter `krylovdim` is referred to as the *restart* parameter,
and our `maxiter` parameter counts the number of outer iterations, i.e. restart cycles. In
Expand Down
214 changes: 214 additions & 0 deletions src/linsolve/lsmr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# reference implementation https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl/blob/master/src/lsmr.jl
function linsolve(operator, b, alg::LSMR)
return linsolve(operator, b, zerovector(apply_adjoint(operator, b)), alg)

Check warning on line 3 in src/linsolve/lsmr.jl

View check run for this annotation

Codecov / codecov/patch

src/linsolve/lsmr.jl#L2-L3

Added lines #L2 - L3 were not covered by tests
end;
function linsolve(operator, b, x₀, alg::LSMR)
u = add!!(apply_normal(operator, x₀), b, 1, -1)
β = norm(u)

# initialize GKL factorization
iter = GKLIterator(operator, u, alg.orth)
fact = initialize(iter; verbosity=alg.verbosity - 2)
numops = 2
sizehint!(fact, alg.krylovdim)

T = eltype(fact)
Tr = real(T)
alg.conlim > 0 ? ctol = convert(Tr, inv(alg.conlim)) : ctol = zero(Tr)
istop = 0

# TODO: make this an explicit copy that works with the testing datatypes
x = x₀

for topit in 1:(alg.maxiter)# the outermost restart loop
# Initialize variables for 1st iteration.
α = fact.αs[end]
ζbar = α * β
αbar = α
ρ = one(Tr)
ρbar = one(Tr)
cbar = one(Tr)
sbar = zero(Tr)

# Initialize variables for estimation of ||r||.
βdd = β
βd = zero(Tr)
ρdold = one(Tr)
τtildeold = zero(Tr)
θtilde = zero(Tr)
ζ = zero(Tr)
d = zero(Tr)

# Initialize variables for estimation of ||A|| and cond(A).
normA, condA, normx = -one(Tr), -one(Tr), -one(Tr)
normA2 = abs2(α)
maxrbar = zero(Tr)
minrbar = 1e100

# Items for use in stopping rules.
normb = β
normr = β
normAr = α * β

hbar = scale(x, zero(T))
h = scale(fact.V[end], one(T))

while length(fact) < alg.krylovdim
β = normres(fact)
fact = expand!(iter, fact)
numops += 2

v = fact.V[end]
α = fact.αs[end]

# Construct rotation Qhat_{k,2k+1}.
αhat = hypot(αbar, alg.λ)
chat = αbar / αhat
shat = alg.λ / αhat

# Use a plane rotation (Q_i) to turn B_i to R_i.
ρold = ρ
ρ = hypot(αhat, β)
c = αhat / ρ
s = β / ρ
θnew = s * α
αbar = c * α

# Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar.
ρbarold = ρbar
ζold = ζ
θbar = sbar * ρ
ρtemp = cbar * ρ
ρbar = hypot(cbar * ρ, θnew)
cbar = cbar * ρ / ρbar
sbar = θnew / ρbar
ζ = cbar * ζbar
ζbar = -sbar * ζbar

# Update h, h_hat, x.
hbar = add!!(hbar, h, 1, -θbar * ρ / (ρold * ρbarold))
h = add!!(h, v, 1, -θnew / ρ)
x = add!!(x, hbar, ζ / (ρ * ρbar), 1)

##############################################################################
##
## Estimate of ||r||
##
##############################################################################

# Apply rotation Qhat_{k,2k+1}.
βacute = chat * βdd
βcheck = -shat * βdd

# Apply rotation Q_{k,k+1}.
βhat = c * βacute
βdd = -s * βacute

# Apply rotation Qtilde_{k-1}.
θtildeold = θtilde
ρtildeold = hypot(ρdold, θbar)
ctildeold = ρdold / ρtildeold
stildeold = θbar / ρtildeold
θtilde = stildeold * ρbar
ρdold = ctildeold * ρbar
βd = -stildeold * βd + ctildeold * βhat

τtildeold = (ζold - θtildeold * τtildeold) / ρtildeold
τd = (ζ - θtilde * τtildeold) / ρdold
d += abs2(βcheck)
normr = sqrt(d + abs2(βd - τd) + abs2(βdd))

# Estimate ||A||.
normA2 += abs2(β)
normA = sqrt(normA2)
normA2 += abs2(α)

# Estimate cond(A).
maxrbar = max(maxrbar, ρbarold)
if length(fact) > 1
minrbar = min(minrbar, ρbarold)
end
condA = max(maxrbar, ρtemp) / min(minrbar, ρtemp)

##############################################################################
##
## Test for convergence
##
##############################################################################

# Compute norms for convergence testing.
normAr = abs(ζbar)
normx = norm(x)

# Now use these norms to estimate certain other quantities,
# some of which will be small near a solution.
test1 = normr / normb
test2 = normAr / (normA * normr)
test3 = inv(condA)

Comment on lines +143 to +148
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused by the placement of these variables. Is there a reason to define these before the logger, and to then use test1, test2 and test3 below with some unexplained conditions? I think it is a bit more readable to keep that together:
test3 + 1 <= 1 vs inv(condA) + 1 <= 1 seems like it reads better

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these different convergence criteria currently used? If not, I would probably streamline them to be compatible with what we use in the other methods, even if this "reduces" the functionality or options to specify the convergence. I will take a stab at this myself.

t1 = test1 / (one(Tr) + normA * normx / normb)
rtol = alg.btol + alg.atol * normA * normx / normb
# The following tests guard against extremely small values of
# atol, btol or ctol. (The user may have set any or all of
# the parameters atol, btol, conlim to 0.)
# The effect is equivalent to the normAl tests using
# atol = eps, btol = eps, conlim = 1/eps.

if alg.verbosity > 2
msg = "LSMR linsolve in iter $topit; step $(length(fact)-1): "
msg *= "normres = "
msg *= @sprintf("%.12e", normr)
@info msg

Check warning on line 161 in src/linsolve/lsmr.jl

View check run for this annotation

Codecov / codecov/patch

src/linsolve/lsmr.jl#L158-L161

Added lines #L158 - L161 were not covered by tests
end

if 1 + test3 <= 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be re-expressed as test3 <= eps(one(test3)) ? This is somewhat strange code, because you could easily think that this just means test3 <= 0, which would be a different check.

istop = 6
break

Check warning on line 166 in src/linsolve/lsmr.jl

View check run for this annotation

Codecov / codecov/patch

src/linsolve/lsmr.jl#L165-L166

Added lines #L165 - L166 were not covered by tests
end
if 1 + test2 <= 1
istop = 5
break

Check warning on line 170 in src/linsolve/lsmr.jl

View check run for this annotation

Codecov / codecov/patch

src/linsolve/lsmr.jl#L169-L170

Added lines #L169 - L170 were not covered by tests
end
if 1 + t1 <= 1
istop = 4
break
end
# Allow for tolerances set by the user.
if test3 <= ctol
istop = 3
break

Check warning on line 179 in src/linsolve/lsmr.jl

View check run for this annotation

Codecov / codecov/patch

src/linsolve/lsmr.jl#L178-L179

Added lines #L178 - L179 were not covered by tests
end
if test2 <= alg.atol
istop = 2
break
end
if test1 <= rtol
istop = 1
break
end
end

u = add!!(apply_normal(operator, x), b, 1, -1)

istop != 0 && break

#restart
β = norm(u)
iter = GKLIterator(operator, u, alg.orth)
fact = initialize!(iter, fact)
end

isconv = istop ∉ (0, 3, 6)
if alg.verbosity > 0 && !isconv
@warn """LSMR linsolve finished without converging after $(alg.maxiter) iterations:

Check warning on line 203 in src/linsolve/lsmr.jl

View check run for this annotation

Codecov / codecov/patch

src/linsolve/lsmr.jl#L203

Added line #L203 was not covered by tests
* norm of residual = $(norm(u))
* number of operations = $numops"""
elseif alg.verbosity > 0
if alg.verbosity > 0
@info """LSMR linsolve converged due to istop $(istop):

Check warning on line 208 in src/linsolve/lsmr.jl

View check run for this annotation

Codecov / codecov/patch

src/linsolve/lsmr.jl#L207-L208

Added lines #L207 - L208 were not covered by tests
* norm of residual = $(norm(u))
* number of operations = $numops"""
end
end
return (x, ConvergenceInfo(Int(isconv), u, norm(u), alg.maxiter, numops))
end
36 changes: 36 additions & 0 deletions test/linsolve.jl
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you should probably also test a non-square matrix here, since that is the primary use-case of the method?

Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,42 @@ end
end
end

# Test LSMR complete
@testset "full lsmr ($mode)" for mode in (:vector, :inplace, :outplace)
scalartypes = mode === :vector ? (Float32, Float64, ComplexF32, ComplexF64) :
(ComplexF64,)
@testset for T in scalartypes
@testset for orth in (cgs2, mgs2, cgsr, mgsr)
A = rand(T, (n, n))
v = rand(T, n)
w = rand(T, n)
alg = LSMR(; orth=orth, krylovdim=2 * n, maxiter=1, atol=10 * n * eps(real(T)),
btol=10 * n * eps(real(T)))
S, info = @inferred linsolve(wrapop(A, Val(mode)), wrapvec(v, Val(mode)),
wrapvec(w, Val(mode)), alg)
@test info.converged > 0
@test v ≈ A * unwrapvec(S) + unwrapvec(info.residual)
end
end
end
@testset "iterative lsmr ($mode)" for mode in (:vector, :inplace, :outplace)
scalartypes = mode === :vector ? (Float32, Float64, ComplexF32, ComplexF64) :
(ComplexF64,)
@testset for T in scalartypes
@testset for orth in (cgs2, mgs2, cgsr, mgsr)
A = rand(T, (N, N))
v = rand(T, N)
w = rand(T, N)
alg = LSMR(; orth=orth, krylovdim=N, maxiter=50, atol=10 * N * eps(real(T)),
btol=10 * N * eps(real(T)))
S, info = @inferred linsolve(wrapop(A, Val(mode)), wrapvec(v, Val(mode)),
wrapvec(w, Val(mode)), alg)
@test info.converged > 0
@test v ≈ A * unwrapvec(S) + unwrapvec(info.residual)
end
end
end

# Test GMRES complete
@testset "GMRES full factorization ($mode)" for mode in (:vector, :inplace, :outplace)
scalartypes = mode === :vector ? (Float32, Float64, ComplexF32, ComplexF64) :
Expand Down
Loading