From 59d69cd7a2c291939381ff78ad29765ef5448b73 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Oct 2023 21:17:46 -0400 Subject: [PATCH 1/2] Prototype SimpleJNFK --- Project.toml | 4 ++- src/SimpleNonlinearSolve.jl | 6 ++-- src/jnfk.jl | 55 +++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 src/jnfk.jl diff --git a/Project.toml b/Project.toml index d72f9d0..1209349 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "0.1.20" +version = "0.1.21" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -9,10 +9,12 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 5c500d5..1360811 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -2,9 +2,10 @@ module SimpleNonlinearSolve using Reexport using FiniteDiff, ForwardDiff -using ForwardDiff: Dual +using ForwardDiff: Dual, Partials, Tag using StaticArraysCore using LinearAlgebra +using LinearSolve import ArrayInterface using DiffEqBase @@ -39,6 +40,7 @@ include("ad.jl") include("halley.jl") include("alefeld.jl") include("itp.jl") +include("jnfk.jl") # Batched Solver Support include("batched/utils.jl") @@ -77,7 +79,7 @@ end # DiffEq styled algorithms export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement, - Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP + Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP, SimpleJFNK export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane end # module diff --git a/src/jnfk.jl b/src/jnfk.jl new file mode 100644 index 0000000..02e1347 --- /dev/null +++ b/src/jnfk.jl @@ -0,0 +1,55 @@ +struct SimpleJNFKJacVecTag end + +function jvp_forwarddiff(f, x::AbstractArray{T}, v) where {T} + v_ = reshape(v, axes(x)) + y = (Dual{Tag{SimpleJNFKJacVecTag, T}, T, 1}).(x, Partials.(tuple.(v_))) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) +end + +struct JacVecOperator{F, X} + f::F + x::X +end + +(jvp::JacVecOperator)(v, _, _) = jvp_forwarddiff(jvp.f, jvp.x, v) + +""" + SimpleJNFK() + +""" +struct SimpleJFNK end + +function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleJFNK, args...; + abstol = nothing, reltol= nothing, maxiters = 1000, linsolve_kwargs = (;), kwargs...) + iip = SciMLBase.isinplace(prob) + @assert !iip "SimpleJFNK does not support inplace problems" + + f = Base.Fix2(prob.f, prob.p) + x = float(prob.u0) + fx = f(x) + T = typeof(x) + + atol = abstol !== nothing ? abstol : + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) + rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) + + op = FunctionOperator(JacVecOperator(f, x), x) + linprob = LinearProblem(op, -fx) + lincache = init(linprob, SimpleGMRES(); abstol, reltol, maxiters, linsolve_kwargs...) + + for i in 1:maxiters + iszero(fx) && + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) + + linsol = solve!(lincache) + x .-= linsol.u + lincache = linsol.cache + + # FIXME: not nothing + if isapprox(x, nothing; atol, rtol) + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) + end + end + + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) +end From 98580953dd23d0db7320ea904c18889a7c741722 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Oct 2023 17:06:48 -0400 Subject: [PATCH 2/2] SimpleJNFK working version --- ext/SimpleNonlinearSolveNNlibExt.jl | 5 +--- src/SimpleNonlinearSolve.jl | 2 +- src/broyden.jl | 2 +- src/jnfk.jl | 36 ++++++++++++++++++++--------- src/raphson.jl | 1 - 5 files changed, 28 insertions(+), 18 deletions(-) diff --git a/ext/SimpleNonlinearSolveNNlibExt.jl b/ext/SimpleNonlinearSolveNNlibExt.jl index 5b06530..cfc3bc7 100644 --- a/ext/SimpleNonlinearSolveNNlibExt.jl +++ b/ext/SimpleNonlinearSolveNNlibExt.jl @@ -4,10 +4,7 @@ using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, Sc import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace -function __init__() - SimpleNonlinearSolve.NNlibExtLoaded[] = true - return -end +SimpleNonlinearSolve.extension_loaded(::Val{NNlib}) = true @views function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedBroyden; diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 1360811..10252ae 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -16,7 +16,7 @@ function __init__() @require_extensions end -const NNlibExtLoaded = Ref{Bool}(false) +extension_loaded(::Val) = false abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end diff --git a/src/broyden.jl b/src/broyden.jl index 6c5c3ce..df0b2d4 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -22,7 +22,7 @@ function Broyden(; batched = false, abstol = nothing, reltol = nothing)) if batched - @assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden." + @assert extension_loaded(Val(:NNlib)) "Please install and load `NNlib.jl` to use batched Broyden." return BatchedBroyden(termination_condition) end return Broyden(termination_condition) diff --git a/src/jnfk.jl b/src/jnfk.jl index 02e1347..b1e4761 100644 --- a/src/jnfk.jl +++ b/src/jnfk.jl @@ -5,6 +5,7 @@ function jvp_forwarddiff(f, x::AbstractArray{T}, v) where {T} y = (Dual{Tag{SimpleJNFKJacVecTag, T}, T, 1}).(x, Partials.(tuple.(v_))) return vec(ForwardDiff.partials.(vec(f(y)), 1)) end +jvp_forwarddiff!(r, f, x, v) = copyto!(r, jvp_forwarddiff(f, x, v)) struct JacVecOperator{F, X} f::F @@ -12,15 +13,24 @@ struct JacVecOperator{F, X} end (jvp::JacVecOperator)(v, _, _) = jvp_forwarddiff(jvp.f, jvp.x, v) +(jvp::JacVecOperator)(r, v, _, _) = jvp_forwarddiff!(r, jvp.f, jvp.x, v) """ - SimpleJNFK() + SimpleJNFK(; batched::Bool = false) +A low overhead Jacobian-free Newton-Krylov method. This method internally uses `GMRES` to +avoid computing the Jacobian Matrix. + +!!! warning + + JNFK doesn't work well without preconditioning, which is currently not supported. We + recommend using `NewtonRaphson(linsolve = KrylovJL_GMRES())` for preconditioning + support. """ struct SimpleJFNK end function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleJFNK, args...; - abstol = nothing, reltol= nothing, maxiters = 1000, linsolve_kwargs = (;), kwargs...) + abstol = nothing, reltol = nothing, maxiters = 1000, linsolve_kwargs = (;), kwargs...) iip = SciMLBase.isinplace(prob) @assert !iip "SimpleJFNK does not support inplace problems" @@ -29,26 +39,30 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleJFNK, args...; fx = f(x) T = typeof(x) + iszero(fx) && + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) + atol = abstol !== nothing ? abstol : real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) op = FunctionOperator(JacVecOperator(f, x), x) - linprob = LinearProblem(op, -fx) - lincache = init(linprob, SimpleGMRES(); abstol, reltol, maxiters, linsolve_kwargs...) + linprob = LinearProblem(op, vec(fx)) + lincache = init(linprob, KrylovJL_GMRES(); abstol = atol, reltol = rtol, maxiters, + linsolve_kwargs...) for i in 1:maxiters - iszero(fx) && - return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) - linsol = solve!(lincache) - x .-= linsol.u + axpy!(-1, linsol.u, x) lincache = linsol.cache - # FIXME: not nothing - if isapprox(x, nothing; atol, rtol) + fx = f(x) + + norm(fx, Inf) ≤ atol && return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) - end + + lincache.b = vec(fx) + lincache.A = FunctionOperator(JacVecOperator(f, x), x) end return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) diff --git a/src/raphson.jl b/src/raphson.jl index 48b8f75..0d9fa80 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -42,7 +42,6 @@ function SimpleNewtonRaphson(; batched = false, throw(ArgumentError("`termination_condition` is currently only supported for batched problems")) end if batched - # @assert ADLinearSolveFDExtLoaded[] "Please install and load `LinearSolve.jl`, `FiniteDifferences.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson." termination_condition = ismissing(termination_condition) ? NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; abstol = nothing,