refactor interface for projections/proximal operators (#147)
* fix outdated type parameters in `LocationScale`

* add `operator` keyword argument to `optimize` so that projection/proximal operators can have their own interface.

* fix benchmark

---------
 
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
Red-Portal authored Dec 30, 2024
1 parent 04a894a commit 54dff15
Showing 22 changed files with 241 additions and 221 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -109,6 +109,7 @@ q_avg, _, stats, _ = AdvancedVI.optimize(
max_iter;
adtype=ADTypes.AutoForwardDiff(),
optimizer=Optimisers.Adam(1e-3),
operator=ClipScale(),
)

# Evaluate final ELBO with 10^3 Monte Carlo samples
2 changes: 2 additions & 0 deletions bench/benchmarks.jl
@@ -41,6 +41,7 @@ begin
max_iter = 10^4
d = LogDensityProblems.dimension(prob)
optimizer = Optimisers.Adam(T(1e-3))
operator = ClipScale()

for (objname, obj) in [
("RepGradELBO", RepGradELBO(10)),
@@ -73,6 +74,7 @@ begin
$max_iter;
adtype=$adtype,
optimizer=$optimizer,
operator=$operator,
show_progress=false,
)
end
3 changes: 3 additions & 0 deletions docs/src/elbo/repgradelbo.md
@@ -219,6 +219,7 @@ _, _, stats_cfe, _ = AdvancedVI.optimize(
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
operator = ClipScale(),
callback = callback,
);
@@ -230,6 +231,7 @@ _, _, stats_stl, _ = AdvancedVI.optimize(
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
operator = ClipScale(),
callback = callback,
);
@@ -317,6 +319,7 @@ _, _, stats_qmc, _ = AdvancedVI.optimize(
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
operator = ClipScale(),
callback = callback,
);
4 changes: 4 additions & 0 deletions docs/src/examples.md
@@ -118,10 +118,14 @@ q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize(
show_progress=false,
adtype=AutoForwardDiff(),
optimizer=Optimisers.Adam(1e-3),
operator=ClipScale(),
);
nothing
```

`ClipScale` is a projection operator, which ensures that the variational approximation stays within a stable region of the variational family.
For more information see [this section](@ref clipscale).

`q_avg_trans` is the final output of the optimization procedure.
If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` will be the output of the averaging strategy, while `q_trans` is the last iterate.
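
The snippet below is a minimal, illustrative sketch of this distinction (it is not part of the file above; it assumes the call above was made with an averaging strategy such as `averager=PolynomialAveraging()`):

```julia
z_avg = rand(q_avg_trans)   # sample from the averaged approximation (usually the one to report)
z_last = rand(q_trans)      # sample from the last iterate
```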

15 changes: 15 additions & 0 deletions docs/src/optimization.md
@@ -26,3 +26,18 @@ PolynomialAveraging
[^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973.
[^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769.
[^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR.

## Operators

Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update.
For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `AdvancedVI.optimize`.
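
For example, a call might look like the following minimal sketch (not part of this file; `prob`, `q_init`, and `max_iter`, as well as the loaded packages, are assumed to be set up as in the general usage example):

```julia
q_avg, q_last, stats, _ = AdvancedVI.optimize(
    prob,                              # target implementing the LogDensityProblems interface
    RepGradELBO(10),                   # variational objective
    q_init,                            # initial variational approximation
    max_iter;
    adtype=ADTypes.AutoForwardDiff(),
    optimizer=Optimisers.Adam(1e-3),
    operator=ClipScale(),              # projection applied to the parameters after each step
)
```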

### [`ClipScale`](@id clipscale)

For location-scale variational families, optimization is often stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
To ensure this, we provide the following projection operator:

```@docs
ClipScale
```
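
The following toy snippet illustrates the projection `ClipScale` performs on the diagonal of the scale matrix (the matrix and the lower bound `ϵ` here are purely illustrative; during optimization this happens automatically through the `operator` keyword):

```julia
using LinearAlgebra

L = [1.0 0.0; 0.3 1e-9]   # lower-triangular scale whose second diagonal entry has collapsed
ϵ = 1e-5                  # illustrative lower bound
diag_idx = diagind(L)
@. L[diag_idx] = max(L[diag_idx], ϵ)   # clamp the diagonal so L*L' stays positive definite
```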

[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
26 changes: 20 additions & 6 deletions ext/AdvancedVIBijectorsExt.jl
@@ -15,24 +15,38 @@ else
using ..Random
end

function AdvancedVI.update_variational_params!(
function AdvancedVI.apply(
op::ClipScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
opt_st,
params,
restructure,
grad,
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.dist.scale_eps
ϵ = convert(eltype(params), op.epsilon)

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.dist.scale)
@. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ)

params, _ = Optimisers.destructure(q)

return opt_st, params
return params
end

function AdvancedVI.apply(
op::ClipScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}},
params,
restructure,
)
q = restructure(params)
ϵ = convert(eltype(params), op.epsilon)

@. q.dist.scale_diag = max(q.dist.scale_diag, ϵ)

params, _ = Optimisers.destructure(q)

return params
end

function AdvancedVI.reparam_with_entropy(
65 changes: 35 additions & 30 deletions src/AdvancedVI.jl
@@ -60,34 +60,6 @@ This is an indirection for handling the type stability of `restructure`, as some
"""
restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)

# Update for gradient descent step
"""
update_variational_params!(family_type, opt_st, params, restructure, grad)
Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`.
This is a wrapper around `Optimisers.update!` to provide some indirection.
For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings.
Same as the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this function.
Instead, the return values should be used.
# Arguments
- `family_type::Type`: Type of the variational family `typeof(restructure(params))`.
- `opt_st`: Optimizer state returned by `Optimisers.setup`.
- `params`: Current set of parameters to be updated.
- `restructure`: Callable for restructuring the varitional distribution from `params`.
- `grad`: Gradient to be used by the update rule of `opt_st`.
# Returns
- `opt_st`: Updated optimizer state.
- `params`: Updated parameters.
"""
function update_variational_params! end

function update_variational_params!(::Type, opt_st, params, restructure, grad)
return Optimisers.update!(opt_st, params, grad)
end

# estimators
"""
AbstractVariationalObjective
@@ -149,7 +121,7 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `params`: Variational parameters to evaluate the gradient on.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
- `restructure`: Function that reconstructs the variational approximation from `params`.
- `obj_state`: Previous state of the objective.
# Returns
@@ -215,7 +187,7 @@ Initialize the state of the averaging strategy `avg` with the initial parameters
init(::AbstractAverager, ::Any) = nothing

"""
apply(avg, avg_st, params)
apply(avg::AbstractAverager, avg_st, params)
Apply averaging strategy `avg` on `params` given the state `avg_st`.
@@ -241,6 +213,39 @@ include("optimization/averaging.jl")

export NoAveraging, PolynomialAveraging

# Operators for Optimization
abstract type AbstractOperator end

"""
apply(op::AbstractOperator, family, params, restructure)
Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator.
# Arguments
- `op::AbstractOperator`: Operator operating on the parameters `params`.
- `family::Type`: Type of the variational approximation `restructure(params)`.
- `params`: Variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `params`.
# Returns
- `oped_params`: Parameters resulting from applying the operator.
"""
function apply(::AbstractOperator, ::Type, ::Any, ::Any) end

"""
IdentityOperator()
Identity operator.
"""
struct IdentityOperator <: AbstractOperator end

apply(::IdentityOperator, ::Type, params, restructure) = params

include("optimization/clip_scale.jl")

export IdentityOperator, ClipScale
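
As an illustration of this extension point (not part of this commit; the operator name and fields below are hypothetical), a user-defined operator would subtype `AbstractOperator` and extend `apply`:

```julia
# Hypothetical example: clamp every variational parameter into a box.
struct ClipBox{T<:Real} <: AbstractOperator
    lower::T
    upper::T
end

apply(op::ClipBox, ::Type, params, restructure) = clamp.(params, op.lower, op.upper)
```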

# Main optimization routine
function optimize end

export optimize
76 changes: 16 additions & 60 deletions src/families/location_scale.jl
@@ -1,14 +1,6 @@

struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
ContinuousMultivariateDistribution
location::L
scale::S
dist::D
scale_eps::E
end

"""
MvLocationScale(location, scale, dist; scale_eps)
MvLocationScale(location, scale, dist)
The location scale variational family broadly represents various variational
families using `location` and `scale` variational parameters.
@@ -20,21 +12,11 @@ represented as follows:
u = rand(dist, d)
z = scale*u + location
```
`scale_eps` sets a constraint on the smallest value of `scale` to be enforced during optimization.
This is necessary to guarantee stable convergence.
# Keyword Arguments
- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `1e-4`).
"""
function MvLocationScale(
location::AbstractVector{T},
scale::AbstractMatrix{T},
dist::ContinuousUnivariateDistribution;
scale_eps::T=T(1e-4),
) where {T<:Real}
@assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior."
return MvLocationScale(location, scale, dist, scale_eps)
struct MvLocationScale{S,D<:ContinuousDistribution,L} <: ContinuousMultivariateDistribution
location::L
scale::S
dist::D
end

Functors.@functor MvLocationScale (location, scale)
@@ -44,18 +26,18 @@ Functors.@functor MvLocationScale (location, scale)
# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
# is very inefficient.
# begin
struct RestructureMeanField{S<:Diagonal,D,L,E}
model::MvLocationScale{S,D,L,E}
struct RestructureMeanField{S<:Diagonal,D,L}
model::MvLocationScale{S,D,L}
end

function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
return MvLocationScale(location, scale, re.model.dist)
end

function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
(; location, scale, dist) = q
flat = vcat(location, diag(scale))
return flat, RestructureMeanField(q)
@@ -66,7 +48,7 @@ Base.length(q::MvLocationScale) = length(q.location)

Base.size(q::MvLocationScale) = size(q.location)

Base.eltype(::Type{<:MvLocationScale{S,D,L,E}}) where {S,D,L,E} = eltype(D)
Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D)

function StatsBase.entropy(q::MvLocationScale)
(; location, scale, dist) = q
@@ -131,55 +113,29 @@ function Distributions.cov(q::MvLocationScale)
end

"""
FullRankGaussian(μ, L; scale_eps)
FullRankGaussian(μ, L)
Construct a Gaussian variational approximation with a dense covariance matrix.
# Arguments
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.
# Keyword Arguments
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
"""
function FullRankGaussian(
μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=T(1e-4)
μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
end

"""
MeanFieldGaussian(μ, L; scale_eps)
MeanFieldGaussian(μ, L)
Construct a Gaussian variational approximation with a diagonal covariance matrix.
# Arguments
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.
# Keyword Arguments
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
"""
function MeanFieldGaussian(
μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=T(1e-4)
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
end

function update_variational_params!(
::Type{<:MvLocationScale}, opt_st, params, restructure, grad
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.scale_eps

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.scale)
@. q.scale[diag_idx] = max(q.scale[diag_idx], ϵ)

params, _ = Optimisers.destructure(q)

return opt_st, params
function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T<:Real}
return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
end

2 comments on commit 54dff15

@Red-Portal (Member, Author)

@JuliaRegistrator register

Release notes:

Breaking changes

  • Complete rewrite of AdvancedVI with major changes in the API. (Refer to general usage and the example.)

New Features

  • Added full-rank and low-rank covariance variational families. (See the docs.)
  • Added the sticking-the-landing control variate. (See the docs.)
  • Added the score gradient estimator of the ELBO gradient with the leave-one-out control variate (also known as VarGrad).
  • Added parameter averaging. (See the docs.)
  • Added parameter-free optimization algorithms. (See the docs.)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/122152

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.3.0 -m "<description of version>" 54dff15d363642d4a0a2cf977186a630332e3ed4
git push origin v0.3.0
