Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

Commit

Permalink
Merge pull request #66 from avik-pal/ap/dfsane_batched
Browse files Browse the repository at this point in the history
Batched DFSane
  • Loading branch information
ChrisRackauckas authored Jun 21, 2023
2 parents dc70eb6 + 625b8f3 commit b2a43e0
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "0.1.15"
version = "0.1.16"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
176 changes: 134 additions & 42 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""
```julia
SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 / k^2)
```
SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing),
batched::Bool = false,
max_inner_iterations::Int = 1000)
A low-overhead implementation of the df-sane method for solving large-scale nonlinear
systems of equations. For in depth information about all the parameters and the algorithm,
Expand Down Expand Up @@ -39,8 +42,16 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_
``f_1=||F(x_1)||^{nexp}``, `k` is the iteration number, `x` is the current `x`-value and
`F` the current residual. Should satisfy ``η_k > 0`` and ``∑ₖ ηₖ < ∞``. Defaults to
``||F||^2 / k^2``.
- `termination_condition`: a `NLSolveTerminationCondition` that determines when the solver
should terminate. Defaults to `NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing, reltol = nothing)`.
- `batched`: if `true`, the algorithm will use a batched version of the algorithm that treats each
column of `x` as a separate problem. This can be useful for nonlinear problems involving neural
networks. Defaults to `false`.
- `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the
algorithm. Used exclusively in `batched` mode. Defaults to `1000`.
"""
struct SimpleDFSane{T} <: AbstractSimpleNonlinearSolveAlgorithm
struct SimpleDFSane{batched, T, TC} <: AbstractSimpleNonlinearSolveAlgorithm
σ_min::T
σ_max::T
σ_1::T
Expand All @@ -50,106 +61,187 @@ struct SimpleDFSane{T} <: AbstractSimpleNonlinearSolveAlgorithm
τ_max::T
nexp::Int
η_strategy::Function
termination_condition::TC
max_inner_iterations::Int

function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 / k^2)
new{typeof(σ_min)}(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, nexp, η_strategy)
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing),
batched::Bool = false,
max_inner_iterations = 1000)
return new{batched, typeof(σ_min), typeof(termination_condition)}(σ_min,
σ_max,
σ_1,
M,
γ,
τ_min,
τ_max,
nexp,
η_strategy,
termination_condition,
max_inner_iterations)
end
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane,
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched},
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
kwargs...)
kwargs...) where {batched}
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)

if batched
batch_size = size(x, 2)
end

T = eltype(x)
σ_min = float(alg.σ_min)
σ_max = float(alg.σ_max)
σ_k = float(alg.σ_1)
σ_k = batched ? fill(float(alg.σ_1), 1, batch_size) : float(alg.σ_1)

M = alg.M
γ = float(alg.γ)
τ_min = float(alg.τ_min)
τ_max = float(alg.τ_max)
nexp = alg.nexp
η_strategy = alg.η_strategy

batched && @assert ndims(x)==2 "Batched SimpleDFSane only supports 2D arrays"

if SciMLBase.isinplace(prob)
error("SimpleDFSane currently only supports out-of-place nonlinear problems")
end

atol = abstol !== nothing ? abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
(tc.abstol !== nothing ? tc.abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
rtol = reltol !== nothing ? reltol :
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))

if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
error("SimpleDFSane currently doesn't support SAFE_BEST termination modes")
end

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing
termination_condition = tc(storage)

function ff(x)
F = f(x)
f_k = norm(F)^nexp
f_k = if batched
sum(abs2, F; dims = 1) .^ (nexp / 2)
else
norm(F)^nexp
end
return f_k, F
end

function generate_history(f_k, M)
if batched
history = similar(f_k, (M, length(f_k)))
history .= reshape(f_k, 1, :)
return history
else
return fill(f_k, M)
end
end

f_k, F_k = ff(x)
α_1 = convert(T, 1.0)
f_1 = f_k
history_f_k = fill(f_k, M)
history_f_k = generate_history(f_k, M)

for k in 1:maxiters
iszero(F_k) &&
return SciMLBase.build_solution(prob, alg, x, F_k;
retcode = ReturnCode.Success)

# Spectral parameter range check
if abs(σ_k) > σ_max
σ_k = sign(σ_k) * σ_max
elseif abs(σ_k) < σ_min
σ_k = sign(σ_k) * σ_min
if batched
@. σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
else
σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
end

# Line search direction
d = -σ_k * F_k
d = -σ_k .* F_k

η = η_strategy(f_1, k, x, F_k)
f̄ = maximum(history_f_k)
f̄ = batched ? maximum(history_f_k; dims = 1) : maximum(history_f_k)
α_p = α_1
α_m = α_1
x_new = x + α_p * d
x_new = @. x + α_p * d

f_new, F_new = ff(x_new)

inner_iterations = 0
while true
if f_new ≤ f̄ + η - γ * α_p^2 * f_k
break
inner_iterations += 1

if batched
criteria = @. f̄ + η - γ * α_p^2 * f_k
# NOTE: This is simply a heuristic, ideally we check using `all` but that is
# typically very expensive for large problems
(sum(f_new .≤ criteria) ≥ batch_size ÷ 2) && break
else
criteria = f̄ + η - γ * α_p^2 * f_k
f_new ≤ criteria && break
end

α_tp = α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k)
x_new = x - α_m * d
α_tp = @. α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k)
x_new = @. x - α_m * d
f_new, F_new = ff(x_new)

if f_new ≤ f̄ + η - γ * α_m^2 * f_k
break
if batched
# NOTE: This is simply a heuristic, ideally we check using `all` but that is
# typically very expensive for large problems
(sum(f_new .≤ criteria) ≥ batch_size ÷ 2) && break
else
f_new ≤ criteria && break
end

α_tm = α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k)
α_p = min(τ_max * α_p, max(α_tp, τ_min * α_p))
α_m = min(τ_max * α_m, max(α_tm, τ_min * α_m))
x_new = x + α_p * d
α_tm = @. α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k)
α_p = @. clamp(α_tp, τ_min * α_p, τ_max * α_p)
α_m = @. clamp(α_tm, τ_min * α_m, τ_max * α_m)
x_new = @. x + α_p * d
f_new, F_new = ff(x_new)

# NOTE: The original algorithm runs till either condition is satisfied, however,
# for most batched problems like neural networks we only care about
# approximate convergence
batched && (inner_iterations ≥ alg.max_inner_iterations) && break
end

if isapprox(x_new, x, atol = atol, rtol = rtol)
return SciMLBase.build_solution(prob, alg, x_new, F_new;
if termination_condition(F_new, x_new, x, atol, rtol)
return SciMLBase.build_solution(prob,
alg,
x_new,
F_new;
retcode = ReturnCode.Success)
end

# Update spectral parameter
s_k = x_new - x
y_k = F_new - F_k
σ_k = (s_k' * s_k) / (s_k' * y_k)
s_k = @. x_new - x
y_k = @. F_new - F_k

if batched
σ_k = sum(abs2, s_k; dims = 1) ./ (sum(s_k .* y_k; dims = 1) .+ T(1e-5))
else
σ_k = (s_k' * s_k) / (s_k' * y_k)
end

# Take step
x = x_new
F_k = F_new
f_k = f_new

# Store function value
history_f_k[k % M + 1] = f_new
if batched
history_f_k[k % M + 1, :] .= vec(f_new)
else
history_f_k[k % M + 1] = f_new
end
end
return SciMLBase.build_solution(prob, alg, x, F_k; retcode = ReturnCode.MaxIters)
end
12 changes: 9 additions & 3 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ const BATCHED_BROYDEN_SOLVERS = Broyden[]
const BROYDEN_SOLVERS = Broyden[]
const BATCHED_LBROYDEN_SOLVERS = LBroyden[]
const LBROYDEN_SOLVERS = LBroyden[]
const BATCHED_DFSANE_SOLVERS = SimpleDFSane[]
const DFSANE_SOLVERS = SimpleDFSane[]

for mode in instances(NLSolveTerminationMode.T)
if mode
Expand All @@ -23,6 +25,8 @@ for mode in instances(NLSolveTerminationMode.T)
push!(BATCHED_BROYDEN_SOLVERS, Broyden(; batched = true, termination_condition))
push!(LBROYDEN_SOLVERS, LBroyden(; batched = false, termination_condition))
push!(BATCHED_LBROYDEN_SOLVERS, LBroyden(; batched = true, termination_condition))
push!(DFSANE_SOLVERS, SimpleDFSane(; batched = false, termination_condition))
push!(BATCHED_DFSANE_SOLVERS, SimpleDFSane(; batched = true, termination_condition))
end

# SimpleNewtonRaphson
Expand Down Expand Up @@ -484,11 +488,13 @@ sol = solve(probN, Broyden(batched = true))

@test abs.(sol.u) ≈ sqrt.(p)

for alg in (BATCHED_BROYDEN_SOLVERS..., BATCHED_LBROYDEN_SOLVERS...)
sol = solve(probN, alg)
for alg in (BATCHED_BROYDEN_SOLVERS...,
BATCHED_LBROYDEN_SOLVERS...,
BATCHED_DFSANE_SOLVERS...)
sol = solve(probN, alg; abstol = 1e-3, reltol = 1e-3)

@test sol.retcode == ReturnCode.Success
@test abs.(sol.u) ≈ sqrt.(p)
@test abs.(sol.u)≈sqrt.(p) atol=1e-3 rtol=1e-3
end

## User specified Jacobian
Expand Down

0 comments on commit b2a43e0

Please sign in to comment.