Compute distribution and ELBO in callback
sethaxen committed Jan 3, 2025
1 parent 9c9fe71 commit e02ad8b
Showing 5 changed files with 148 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/Pathfinder.jl
@@ -41,8 +41,8 @@ default_ad() = ADTypes.AutoForwardDiff()

include("transducers.jl")
include("woodbury.jl")
include("optimize.jl")
include("lbfgs.jl")
include("optimize.jl")
include("mvnormal.jl")
include("elbo.jl")
include("resample.jl")
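(Note: the reordering above moves include("optimize.jl") after include("lbfgs.jl"); the new code in optimize.jl below references LBFGSState, _update_state!, and gilbert_invH_init!, which appear to live in lbfgs.jl, so that file must now be loaded first.)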
2 changes: 1 addition & 1 deletion src/elbo.jl
@@ -27,7 +27,7 @@ function elbo_and_samples!(ϕ, rng, logp, dist; save_samples::Bool=true)
return ELBOEstimate(elbo, elbo_se, ϕ_save, logpϕ, logqϕ, logr)
end

struct ELBOEstimate{T,P,L<:AbstractVector{T}}
struct ELBOEstimate{T,P<:AbstractMatrix{T},L<:Vector{T}}
value::T
std_err::T
draws::P
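For orientation, the tightened type parameters match the new workflow below, where draws are written into a preallocated matrix with one column per draw. A minimal self-contained sketch of the quantities ELBOEstimate stores, assuming the usual Monte Carlo ELBO estimator and its standard error (toy target and names, not this package's API):

    using Distributions, LinearAlgebra, Random, Statistics

    rng = Random.Xoshiro(1)
    q = MvNormal(zeros(2), Matrix{Float64}(I, 2, 2))  # fitted approximation
    logp(x) = -sum(abs2, x .- 1) / 2                  # toy unnormalized target

    ϕ = rand(rng, q, 5)                         # draws: a 2×5 matrix (P <: AbstractMatrix)
    logqϕ = [logpdf(q, c) for c in eachcol(ϕ)]  # log q(ϕᵢ)
    logpϕ = [logp(c) for c in eachcol(ϕ)]       # log p(ϕᵢ)
    logr = logpϕ .- logqϕ                       # log density ratios (L <: Vector)
    elbo = mean(logr)                           # ELBO estimate
    elbo_se = std(logr) / sqrt(length(logr))    # its Monte Carlo standard error

These are the same six quantities passed to the ELBOEstimate constructor in the return statement shown above.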
6 changes: 6 additions & 0 deletions src/mvnormal.jl
@@ -22,6 +22,12 @@ function fit_mvnormals(θs, logpθs, ∇logpθs; kwargs...)
return dists, num_bfgs_updates_rejected
end

function fit_mvnormal(state::LBFGSState)
Σ = state.invH
μ = muladd(Σ, state.∇fx, state.x)
return Distributions.MvNormal(μ, Σ)
end
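fit_mvnormal reads the current iterate x, the stored gradient ∇fx, and the inverse-Hessian estimate invH off the L-BFGS state and returns MvNormal(x + invH * ∇fx, invH), the per-iterate normal approximation Pathfinder fits along the optimization trajectory. A standalone sketch with a hypothetical stand-in for LBFGSState (in the real code invH is presumably a structured WoodburyPDMat rather than a dense matrix):

    using Distributions, LinearAlgebra

    struct ToyState{T}          # hypothetical stand-in for the LBFGSState fields read here
        x::Vector{T}            # current iterate
        ∇fx::Vector{T}          # stored gradient
        invH::Matrix{T}         # inverse-Hessian estimate Σ
    end

    toy_fit_mvnormal(s::ToyState) = MvNormal(muladd(s.invH, s.∇fx, s.x), Symmetric(s.invH))

    s = ToyState([1.0, 2.0], [0.5, -0.5], Matrix(0.1I, 2, 2))
    dist = toy_fit_mvnormal(s)   # mean = x + invH * ∇fx ≈ [1.05, 1.95]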

# faster than computing `logpdf` and `rand!` independently
function rand_and_logpdf!(rng, dist::Distributions.MvNormal, x)
(; μ, Σ) = dist
Expand Down
154 changes: 130 additions & 24 deletions src/optimize.jl
@@ -55,22 +55,73 @@ function optimize_with_trace(
maxiters=1_000,
callback=nothing,
fail_on_nonfinite=true,
ndraws_elbo::Int=5,
rng=Random.GLOBAL_RNG,
(invH_init!)=gilbert_invH_init!,
save_trace::Bool=true,
kwargs...,
)
u0 = prob.u0
fun = prob.f
if prob.f.grad === nothing
# Generate a cache to use Optimization's native functionality for adding missing
# gradient values
fun = Optimization.OptimizationCache(prob, optimizer).f
if fun.grad === nothing
throw(
ArgumentError(
"Gradient function is not available. Please provide an OptimizationProblem with an explicit gradient function.",
),
)
end
else
fun = prob.f
end

logp(x) = -fun.f(x, nothing)
function ∇logp(x)
SciMLBase.isinplace(fun) || return -fun.grad(x)
res = similar(x)
fun.grad(res, x)
rmul!(res, -1)
return res
end
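These closures recover logp and ∇logp from the minimized objective (Optimization.jl minimizes f = -logp), handling both in-place and out-of-place gradient definitions. A minimal sketch of the kind of problem they expect, with a toy target (hypothetical example, not from this diff):

    using Optimization

    # target: logp(x) = -sum(x.^2)/2, so the objective handed to the optimizer is -logp
    negative_logp(x, p) = sum(abs2, x) / 2
    grad!(res, x, p) = (res .= x)            # in-place gradient of -logp
    optfun = OptimizationFunction(negative_logp; grad=grad!)
    prob = OptimizationProblem(optfun, randn(2))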

# caches for the trace of x and f(x)
(; u0) = prob
T = eltype(u0)
xs = typeof(u0)[]
fxs = typeof(fun.f(u0, nothing))[]
∇fxs = Union{Nothing,typeof(u0)}[]
∇fxs = typeof(u0)[]
optim_trace = OptimizationTrace(xs, fxs, ∇fxs)
# TODO: fix this
lbfgs_state = LBFGSState(u0, -logp(u0), ∇logp(u0), 10)
draws_cache = similar(u0, size(u0, 1), ndraws_elbo)
elbo_estimates = ELBOEstimate{T,typeof(draws_cache),Vector{T}}[]
# TODO: make this a concrete type
fit_distributions = typeof(fit_mvnormal(lbfgs_state))[]

# TODO: keep deepcopy of ELBO-maximizing fit distribution so far, iteration where built,
# and maximum ELBO value

_callback = OptimizationCallback(
xs, fxs, ∇fxs, progress_name, progress_id, maxiters, callback, fail_on_nonfinite
logp,
∇logp,
rng,
save_trace,
maxiters,
fail_on_nonfinite,
callback,
lbfgs_state,
draws_cache,
optim_trace,
fit_distributions,
elbo_estimates,
invH_init!,
progress_name,
progress_id,
)
sol = Optimization.solve(prob, optimizer; callback=_callback, maxiters, kwargs...)

_∇fxs = _fill_missing_gradient_values!(∇fxs, xs, sol.cache.f)

return sol, OptimizationTrace(xs, fxs, _∇fxs)
return sol, optim_trace, fit_distributions, elbo_estimates[(begin + 1):end]
end
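(The (begin + 1):end slice appears to drop the estimate for the very first state, which corresponds to the initial point rather than to an optimizer step; the removed code in singlepath.jl below applied the same slice to fit_distributions before calling maximize_elbo.)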

function _fill_missing_gradient_values!(∇fxs, xs, optim_fun)
@@ -87,37 +138,92 @@ function _fill_missing_gradient_values!(∇fxs, xs, optim_fun)
return convert(typeof(xs), ∇fxs)
end

struct OptimizationCallback{X,FX,∇FX,ID,CB}
xs::X
fxs::FX
∇fxs::∇FX
progress_name::String
progress_id::ID
struct OptimizationCallback{
F,
DF,
R<:Random.AbstractRNG,
CB,
L<:LBFGSState,
DC<:AbstractMatrix,
OT,
FD<:Vector{<:Distributions.MvNormal},
EE<:Vector,
IH,
ID,
}
# Generated functions
logp::F
∇logp::DF
# User-provided options
rng::R
save_trace::Bool
maxiters::Int
callback::CB
fail_on_nonfinite::Bool
callback::CB
# State/caches
lbfgs_state::L
draws_cache::DC
optim_trace::OT
fit_distributions::FD
elbo_estimates::EE
# Internally set options
invH_init!::IH
progress_name::String
progress_id::ID
end

function (cb::OptimizationCallback)(state::Optimization.OptimizationState, args...)
(;
xs, fxs, ∇fxs, progress_name, progress_id, maxiters, callback, fail_on_nonfinite
logp,
∇logp,
rng,
save_trace,
maxiters,
fail_on_nonfinite,
callback,
lbfgs_state,
optim_trace,
fit_distributions,
elbo_estimates,
draws_cache,
invH_init!,
progress_name,
progress_id,
) = cb
ret = callback !== nothing && callback(state, args...)
iteration = state.iter
Base.@logmsg ProgressLogging.ProgressLevel progress_name progress =
iteration / maxiters _id = progress_id
Base.@logmsg ProgressLogging.ProgressLevel progress_name progress = iteration / maxiters _id =
progress_id

# some optimizers mutate x, so we must copy it
x = copy(state.u)
fx = -state.objective
∇fx = state.grad === nothing ? nothing : -state.grad
logp_x = -state.objective
∇logp_x = state.grad === nothing ? ∇logp(x) : -state.grad

# Update L-BFGS state
ϵ = sqrt(eps(eltype(x)))
_update_state!(lbfgs_state, x, -logp_x, -∇logp_x, invH_init!, ϵ)

# some backends mutate x, so we must copy it
push!(xs, x)
push!(fxs, fx)
push!(∇fxs, ∇fx)
# Fit distribution
fit_distribution = fit_mvnormal(lbfgs_state)
elbo_estimate = elbo_and_samples!(
draws_cache, rng, logp, fit_distribution; save_samples=save_trace
)

push!(optim_trace.log_densities, logp_x)
push!(elbo_estimates, elbo_estimate)
if save_trace
push!(optim_trace.points, x)
push!(optim_trace.gradients, ∇logp_x)
push!(fit_distributions, fit_distribution)
end

if fail_on_nonfinite && !ret
ret = (isnan(fx) || fx == Inf || (∇fx !== nothing && any(!isfinite, ∇fx)))::Bool
ret = (
isnan(logp_x) ||
logp_x == -Inf ||
(∇logp_x !== nothing && any(!isfinite, ∇logp_x))
)::Bool
end

return ret
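For context, Optimization.jl treats a callback that returns true as a request to halt, so the non-finite guard above terminates the run as soon as the objective or gradient degenerates. The same check in isolation, with toy values:

    logp_x = NaN
    ∇logp_x = [0.0, 1.0]
    stop = isnan(logp_x) || logp_x == -Inf || (∇logp_x !== nothing && any(!isfinite, ∇logp_x))
    # stop == true here, so solve would return early with the current solution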
20 changes: 10 additions & 10 deletions src/singlepath.jl
@@ -294,24 +294,24 @@ function _pathfinder(
logp;
history_length::Int=DEFAULT_HISTORY_LENGTH,
optimizer=default_optimizer(history_length),
ndraws_elbo=DEFAULT_NDRAWS_ELBO,
executor::Transducers.Executor=Transducers.SequentialEx(),
kwargs...,
)
# compute trajectory
optim_solution, optim_trace = optimize_with_trace(prob, optimizer; kwargs...)
optim_solution, optim_trace, fit_distributions, elbo_estimates = optimize_with_trace(
prob, optimizer; rng, kwargs...
)
num_bfgs_updates_rejected = 0
L = length(optim_trace) - 1
success = L > 0

# fit mv-normal distributions to trajectory
fit_distributions, num_bfgs_updates_rejected = fit_mvnormals(
optim_trace.points, optim_trace.log_densities, optim_trace.gradients; history_length
)

# find ELBO-maximizing distribution
fit_iteration, elbo_estimates = @views maximize_elbo(
rng, logp, fit_distributions[(begin + 1):end], ndraws_elbo, executor
)
if isempty(elbo_estimates)
fit_iteration = 0
else
_, fit_iteration = _findmax(est.value for est in elbo_estimates)
end

if isempty(elbo_estimates)
success = false
else
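(Note: num_bfgs_updates_rejected is now hard-coded to 0 because fit_mvnormals, which counted rejected inverse-Hessian updates, is no longer called here; restoring that count presumably falls under the TODOs left in optimize.jl.)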
