diff --git a/Project.toml b/Project.toml index d72f9d0..1209349 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "0.1.20" +version = "0.1.21" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -9,10 +9,12 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] diff --git a/ext/SimpleNonlinearSolveNNlibExt.jl b/ext/SimpleNonlinearSolveNNlibExt.jl index 5b06530..cfc3bc7 100644 --- a/ext/SimpleNonlinearSolveNNlibExt.jl +++ b/ext/SimpleNonlinearSolveNNlibExt.jl @@ -4,10 +4,7 @@ using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, Sc import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace -function __init__() - SimpleNonlinearSolve.NNlibExtLoaded[] = true - return -end +SimpleNonlinearSolve.extension_loaded(::Val{NNlib}) = true @views function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedBroyden; diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 5c500d5..10252ae 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -2,9 +2,10 @@ module SimpleNonlinearSolve using Reexport using FiniteDiff, ForwardDiff -using ForwardDiff: Dual +using ForwardDiff: Dual, Partials, Tag using StaticArraysCore using LinearAlgebra +using LinearSolve import ArrayInterface using DiffEqBase @@ -15,7 +16,7 @@ function __init__() @require_extensions end -const NNlibExtLoaded = Ref{Bool}(false) +extension_loaded(::Val) = false abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end @@ -39,6 +40,7 @@ include("ad.jl") include("halley.jl") include("alefeld.jl") include("itp.jl") +include("jnfk.jl") # Batched Solver Support include("batched/utils.jl") @@ -77,7 +79,7 @@ end # DiffEq styled algorithms export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement, - Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP + Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP, SimpleJFNK export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane end # module diff --git a/src/broyden.jl b/src/broyden.jl index 6c5c3ce..df0b2d4 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -22,7 +22,7 @@ function Broyden(; batched = false, abstol = nothing, reltol = nothing)) if batched - @assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden." + @assert extension_loaded(Val(:NNlib)) "Please install and load `NNlib.jl` to use batched Broyden." return BatchedBroyden(termination_condition) end return Broyden(termination_condition) diff --git a/src/jnfk.jl b/src/jnfk.jl new file mode 100644 index 0000000..b1e4761 --- /dev/null +++ b/src/jnfk.jl @@ -0,0 +1,69 @@ +struct SimpleJNFKJacVecTag end + +function jvp_forwarddiff(f, x::AbstractArray{T}, v) where {T} + v_ = reshape(v, axes(x)) + y = (Dual{Tag{SimpleJNFKJacVecTag, T}, T, 1}).(x, Partials.(tuple.(v_))) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) +end +jvp_forwarddiff!(r, f, x, v) = copyto!(r, jvp_forwarddiff(f, x, v)) + +struct JacVecOperator{F, X} + f::F + x::X +end + +(jvp::JacVecOperator)(v, _, _) = jvp_forwarddiff(jvp.f, jvp.x, v) +(jvp::JacVecOperator)(r, v, _, _) = jvp_forwarddiff!(r, jvp.f, jvp.x, v) + +""" + SimpleJNFK(; batched::Bool = false) + +A low overhead Jacobian-free Newton-Krylov method. This method internally uses `GMRES` to +avoid computing the Jacobian Matrix. + +!!! warning + + JNFK doesn't work well without preconditioning, which is currently not supported. We + recommend using `NewtonRaphson(linsolve = KrylovJL_GMRES())` for preconditioning + support. +""" +struct SimpleJFNK end + +function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleJFNK, args...; + abstol = nothing, reltol = nothing, maxiters = 1000, linsolve_kwargs = (;), kwargs...) + iip = SciMLBase.isinplace(prob) + @assert !iip "SimpleJFNK does not support inplace problems" + + f = Base.Fix2(prob.f, prob.p) + x = float(prob.u0) + fx = f(x) + T = typeof(x) + + iszero(fx) && + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) + + atol = abstol !== nothing ? abstol : + real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) + rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) + + op = FunctionOperator(JacVecOperator(f, x), x) + linprob = LinearProblem(op, vec(fx)) + lincache = init(linprob, KrylovJL_GMRES(); abstol = atol, reltol = rtol, maxiters, + linsolve_kwargs...) + + for i in 1:maxiters + linsol = solve!(lincache) + axpy!(-1, linsol.u, x) + lincache = linsol.cache + + fx = f(x) + + norm(fx, Inf) ≤ atol && + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) + + lincache.b = vec(fx) + lincache.A = FunctionOperator(JacVecOperator(f, x), x) + end + + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) +end diff --git a/src/raphson.jl b/src/raphson.jl index 48b8f75..0d9fa80 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -42,7 +42,6 @@ function SimpleNewtonRaphson(; batched = false, throw(ArgumentError("`termination_condition` is currently only supported for batched problems")) end if batched - # @assert ADLinearSolveFDExtLoaded[] "Please install and load `LinearSolve.jl`, `FiniteDifferences.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson." termination_condition = ismissing(termination_condition) ? NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; abstol = nothing,