Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial lsmr draft #46

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/KrylovKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export initialize, initialize!, expand!, shrink!
export ClassicalGramSchmidt, ClassicalGramSchmidt2, ClassicalGramSchmidtIR
export ModifiedGramSchmidt, ModifiedGramSchmidt2, ModifiedGramSchmidtIR
export LanczosIterator, ArnoldiIterator, GKLIterator
export CG, GMRES, Lanczos, Arnoldi, GKL, GolubYe
export CG, GMRES, Lanczos, Arnoldi, GKL, GolubYe, LSMR;
export KrylovDefaults, ClosestTo, EigSorter
export RecursiveVec, InnerProductVec

Expand Down Expand Up @@ -344,6 +344,7 @@ include("eigsolve/svdsolve.jl")
include("linsolve/linsolve.jl")
include("linsolve/cg.jl")
include("linsolve/gmres.jl")
include("linsolve/lsmr.jl")

# exponentiate
include("matrixfun/exponentiate.jl")
Expand Down
44 changes: 44 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,50 @@ GMRES(; krylovdim::Integer = KrylovDefaults.krylovdim,
verbosity::Int = 0) =
GMRES(orth, maxiter, krylovdim, tol, verbosity)

"""
LSMR(; orth = KrylovDefaults.orth,atol = KrylovDefaults.tol,btol = KrylovDefaults.tol,conlim = 1/KrylovDefaults.tol,
maxiter = KrylovDefaults.maxiter,krylovdim = KrylovDefaults.krylovdim,λ = 0.0,verbosity = 0)

Represents the LSMR algorithm, which minimizes ``\\|Ax - b\\|^2 + \\|λx\\|^2`` in the Euclidean norm.
If multiple solutions exists the minimum norm solution is returned.
The method is based on the Golub-Kahan bidiagonalization process. It is
algebraically equivalent to applying MINRES to the normal equations
``(A^*A + λ^2I)x = A^*b``, but has better numerical properties,
especially if ``A`` is ill-conditioned.

- `atol::Number = 1e-6`, `btol::Number = 1e-6`: stopping tolerances. If both are
1.0e-9 (say), the final residual norm should be accurate to about 9 digits.
(The final `x` will usually have fewer correct digits,
depending on `cond(A)` and the size of damp).
- `conlim::Number = 1e8`: stopping tolerance. `lsmr` terminates if an estimate
of `cond(A)` exceeds conlim. For compatible systems Ax = b,
conlim could be as large as 1.0e+12 (say). For least-squares
problems, conlim should be less than 1.0e+8.
Maximum precision can be obtained by setting
- `atol` = `btol` = `conlim` = zero, but the number of iterations
may then be excessive.
"""
struct LSMR{O<:Orthogonalizer,S<:Real} <: KrylovAlgorithm
orth::O
atol::S
btol::S
conlim::S
maxiter::Int
verbosity::Int
λ::S
krylovdim::Int
end
LSMR(; orth = KrylovDefaults.orth,
atol = KrylovDefaults.tol,
btol = KrylovDefaults.tol,
conlim = 1/min(atol,btol),
maxiter = KrylovDefaults.maxiter,
krylovdim = KrylovDefaults.krylovdim,
λ = zero(atol),
verbosity = 0) = LSMR(orth,atol,btol,conlim,maxiter,verbosity,λ,krylovdim)



# TODO
"""
MINRES(; maxiter = KrylovDefaults.maxiter, tol = KrylovDefaults.tol)
Expand Down
7 changes: 4 additions & 3 deletions src/krylov/gkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,15 @@ function initialize(iter::GKLIterator; verbosity::Int = 0)
end
function initialize!(iter::GKLIterator, state::GKLFactorization; verbosity::Int = 0)
V = state.V
while length(U) > 1
pop!(U)
u = state.U[1];
while length(state.U) > 1
pop!(state.U)
end
V = empty!(state.V)
αs = empty!(state.αs)
βs = empty!(state.βs)

u = mul!(V[1], iter.u₀, 1/norm(iter.u₀))
u = mul!(u, iter.u₀, 1/norm(iter.u₀))
v = iter.operator(u, true)
α = norm(v)
rmul!(v, 1/α)
Expand Down
194 changes: 194 additions & 0 deletions src/linsolve/lsmr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# reference implementation https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl/blob/master/src/lsmr.jl
linsolve(operator, b, alg::LSMR) = linsolve(operator,b,svdfun(operator)(x,true),alg);
function linsolve(operator, b, x, alg::LSMR)
u = axpby!(1,b,-1,svdfun(operator)(x,false))
β = norm(u);

# initialize GKL factorization
iter = GKLIterator(svdfun(operator), u, alg.orth)
fact = initialize(iter; verbosity = alg.verbosity-2)
numops = 2
sizehint!(fact, alg.krylovdim)

T = eltype(fact);
Tr = real(T)
alg.conlim > 0 ? ctol = convert(Tr, inv(alg.conlim)) : ctol = zero(Tr);
istop = 0;

for topit = 1:alg.maxiter# the outermost restart loop
# Initialize variables for 1st iteration.
α = fact.αs[end];
ζbar = α * β
αbar = α
ρ = one(Tr)
ρbar = one(Tr)
cbar = one(Tr)
sbar = zero(Tr)

# Initialize variables for estimation of ||r||.
βdd = β
βd = zero(Tr)
ρdold = one(Tr)
τtildeold = zero(Tr)
θtilde = zero(Tr)
ζ = zero(Tr)
d = zero(Tr)

# Initialize variables for estimation of ||A|| and cond(A).
normA, condA, normx = -one(Tr), -one(Tr), -one(Tr)
normA2 = abs2(α)
maxrbar = zero(Tr)
minrbar = 1e100

# Items for use in stopping rules.
normb = β
normr = β
normAr = α * β

hbar = zero(T)*x;
h = one(T)*fact.V[end];

while length(fact) < alg.krylovdim

β = normres(fact);
fact = expand!(iter, fact)
numops += 2;

v = fact.V[end];
α = fact.αs[end];

# Construct rotation Qhat_{k,2k+1}.
αhat = hypot(αbar, alg.λ)
chat = αbar / αhat
shat = alg.λ / αhat

# Use a plane rotation (Q_i) to turn B_i to R_i.
ρold = ρ
ρ = hypot(αhat, β)
c = αhat / ρ
s = β / ρ
θnew = s * α
αbar = c * α

# Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar.
ρbarold = ρbar
ζold = ζ
θbar = sbar * ρ
ρtemp = cbar * ρ
ρbar = hypot(cbar * ρ, θnew)
cbar = cbar * ρ / ρbar
sbar = θnew / ρbar
ζ = cbar * ζbar
ζbar = - sbar * ζbar

# Update h, h_hat, x.
hbar = axpby!(1,h,-θbar * ρ / (ρold * ρbarold),hbar);
h = axpby!(1,v,-θnew / ρ,h);
x = axpy!(ζ / (ρ * ρbar),hbar,x)

##############################################################################
##
## Estimate of ||r||
##
##############################################################################

# Apply rotation Qhat_{k,2k+1}.
βacute = chat * βdd
βcheck = - shat * βdd

# Apply rotation Q_{k,k+1}.
βhat = c * βacute
βdd = - s * βacute

# Apply rotation Qtilde_{k-1}.
θtildeold = θtilde
ρtildeold = hypot(ρdold, θbar)
ctildeold = ρdold / ρtildeold
stildeold = θbar / ρtildeold
θtilde = stildeold * ρbar
ρdold = ctildeold * ρbar
βd = - stildeold * βd + ctildeold * βhat

τtildeold = (ζold - θtildeold * τtildeold) / ρtildeold
τd = (ζ - θtilde * τtildeold) / ρdold
d += abs2(βcheck)
normr = sqrt(d + abs2(βd - τd) + abs2(βdd))

# Estimate ||A||.
normA2 += abs2(β)
normA = sqrt(normA2)
normA2 += abs2(α)

# Estimate cond(A).
maxrbar = max(maxrbar, ρbarold)
if length(fact) > 1
minrbar = min(minrbar, ρbarold)
end
condA = max(maxrbar, ρtemp) / min(minrbar, ρtemp)


##############################################################################
##
## Test for convergence
##
##############################################################################

# Compute norms for convergence testing.
normAr = abs(ζbar)
normx = norm(x)

# Now use these norms to estimate certain other quantities,
# some of which will be small near a solution.
test1 = normr / normb
test2 = normAr / (normA * normr)
test3 = inv(condA)

t1 = test1 / (one(Tr) + normA * normx / normb)
rtol = alg.btol + alg.atol * normA * normx / normb
# The following tests guard against extremely small values of
# atol, btol or ctol. (The user may have set any or all of
# the parameters atol, btol, conlim to 0.)
# The effect is equivalent to the normAl tests using
# atol = eps, btol = eps, conlim = 1/eps.

if alg.verbosity > 2
msg = "LSMR linsolve in iter $topit; step $(length(fact)-1): "
msg *= "normres = "
msg *= @sprintf("%.12e", normr)
@info msg
end

if 1 + test3 <= 1 istop = 6; break end
if 1 + test2 <= 1 istop = 5; break end
if 1 + t1 <= 1 istop = 4; break end
# Allow for tolerances set by the user.
if test3 <= ctol istop = 3; break end
if test2 <= alg.atol istop = 2; break end
if test1 <= rtol istop = 1; break end
end

u = axpby!(1,b,-1,svdfun(operator)(x,false))

istop != 0 && break;

#restart
β = norm(u);
iter = GKLIterator(svdfun(operator), u, alg.orth);
fact = initialize!(iter,fact);
end

isconv = istop ∉ (0,3,6);
if alg.verbosity > 0 && !isconv
@warn """LSMR linsolve finished without converging after $(alg.maxiter) iterations:
* norm of residual = $(norm(u))
* number of operations = $numops"""
elseif alg.verbosity > 0
if alg.verbosity > 0
@info """LSMR linsolve converged due to istop $(istop):
* norm of residual = $(norm(u))
* number of operations = $numops"""
end
end
return (x, ConvergenceInfo(Int(isconv), u, norm(u), alg.maxiter, numops))

end
29 changes: 29 additions & 0 deletions test/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,32 @@ end
@test b ≈ (α₀*I+α₁*A)*unwrapvec(x) + unwrapvec(info.residual)
end
end

@testset "full lsmr" begin
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset for orth in (cgs2, mgs2, cgsr, mgsr)
A = rand(T, (n,n))
v = rand(T,n);
w = rand(T,n);
alg = LSMR(orth = orth, krylovdim = 2*n, maxiter = 1, atol = 10*n*eps(real(T)),btol = 10*n*eps(real(T)))
S, info = @inferred linsolve(wrapop(A), wrapvec(v),wrapvec(w), alg)
@test info.converged > 0
@test v≈A*unwrapvec(S)+unwrapvec(info.residual)
end
end
end

@testset "iterative lsmr" begin
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset for orth in (cgs2, mgs2, cgsr, mgsr)
A = rand(T, (N,N))
v = rand(T,N);
w = rand(T,N);
alg = LSMR(orth = orth, krylovdim = N, maxiter = 50, atol = 10*N*eps(real(T)),btol = 10*N*eps(real(T)))
S, info = @inferred linsolve(wrapop(A), wrapvec(v),wrapvec(w), alg)
@test info.converged > 0

@test v≈A*unwrapvec(S)+unwrapvec(info.residual)
end
end
end