refactor interface for projections/proximal operators #147

Merged
merged 23 commits into master from projected_proximal_location_scale on Dec 30, 2024
Commits (23)
238128e
refactor make scale projection operator its own optimization rule
Red-Portal Nov 17, 2024
03338d6
add docs for `ProjectScale`
Red-Portal Nov 17, 2024
233cffa
refactor change of type parameter order for `LocationScaleLowRank`
Red-Portal Nov 17, 2024
960d77d
apply formatter
Red-Portal Nov 17, 2024
db42115
apply formatter
Red-Portal Nov 17, 2024
6dd0fd6
apply formatter
Red-Portal Nov 17, 2024
074218a
update README
Red-Portal Nov 17, 2024
a11e5ce
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into pro…
Red-Portal Dec 9, 2024
a3ce1d1
fix formatting
Red-Portal Dec 9, 2024
ee36164
fix outdated type parameters in `LocationScale`
Red-Portal Dec 10, 2024
cd35e4e
rename averaging function
Red-Portal Dec 24, 2024
f40df75
fix projection/proximal operator interface
Red-Portal Dec 24, 2024
97f64e1
update documentation
Red-Portal Dec 24, 2024
9f1a549
fix formatting
Red-Portal Dec 24, 2024
ebe0637
fix benchmark
Red-Portal Dec 24, 2024
dcf21db
add missing test file
Red-Portal Dec 24, 2024
7868317
fix documentation
Red-Portal Dec 24, 2024
04db344
fix documentation
Red-Portal Dec 24, 2024
f731bdc
fix ambiguous specialization error for `operate`
Red-Portal Dec 24, 2024
86e1ab3
update documentation
Red-Portal Dec 27, 2024
1b3b734
refactor `average` and `operate` to specializations of `apply`
Red-Portal Dec 29, 2024
9887bb4
Merge branch 'projected_proximal_location_scale' of github.com:Turing…
Red-Portal Dec 29, 2024
635ea4e
Update docs/src/optimization.md
yebai Dec 29, 2024
1 change: 1 addition & 0 deletions README.md
@@ -109,6 +109,7 @@ q_avg, _, stats, _ = AdvancedVI.optimize(
max_iter;
adtype=ADTypes.AutoForwardDiff(),
optimizer=Optimisers.Adam(1e-3),
operator=ClipScale(),
)

# Evaluate final ELBO with 10^3 Monte Carlo samples
2 changes: 2 additions & 0 deletions bench/benchmarks.jl
@@ -41,6 +41,7 @@ begin
max_iter = 10^4
d = LogDensityProblems.dimension(prob)
optimizer = Optimisers.Adam(T(1e-3))
operator = ClipScale()

for (objname, obj) in [
("RepGradELBO", RepGradELBO(10)),
@@ -73,6 +74,7 @@ begin
$max_iter;
adtype=$adtype,
optimizer=$optimizer,
operator=$operator,
show_progress=false,
)
end
3 changes: 3 additions & 0 deletions docs/src/elbo/repgradelbo.md
@@ -219,6 +219,7 @@ _, _, stats_cfe, _ = AdvancedVI.optimize(
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
operator = ClipScale(),
callback = callback,
);

@@ -230,6 +231,7 @@ _, _, stats_stl, _ = AdvancedVI.optimize(
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
operator = ClipScale(),
callback = callback,
);

@@ -317,6 +319,7 @@ _, _, stats_qmc, _ = AdvancedVI.optimize(
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
operator = ClipScale(),
callback = callback,
);

4 changes: 4 additions & 0 deletions docs/src/examples.md
@@ -118,10 +118,14 @@ q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize(
show_progress=false,
adtype=AutoForwardDiff(),
optimizer=Optimisers.Adam(1e-3),
operator=ClipScale(),
);
nothing
```

`ClipScale` is a projection operator that ensures the variational approximation stays within a stable region of the variational family.
For more information, see [this section](@ref clipscale).

`q_avg_trans` is the final output of the optimization procedure.
If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` will be the output of the averaging strategy, while `q_trans` is the last iterate.

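For intuition, the projection `ClipScale` performs amounts to clamping the diagonal of the scale matrix from below. A minimal standalone sketch of that operation, with a hypothetical matrix and threshold that are not part of this PR:

```julia
using LinearAlgebra

# Hypothetical nearly-singular lower-triangular scale factor.
scale = [2.0 0.0; 0.3 1e-6]
ϵ = 1e-4  # lower bound, corresponding to ClipScale's `epsilon` field

# Clamp the diagonal from below; for a triangular matrix, the eigenvalues
# are the diagonal entries, so the smallest eigenvalue is now at least ϵ.
diag_idx = diagind(scale)
@. scale[diag_idx] = max(scale[diag_idx], ϵ)
```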
15 changes: 15 additions & 0 deletions docs/src/optimization.md
@@ -26,3 +26,18 @@ PolynomialAveraging
[^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973.
[^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769.
[^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR.

## Operators

Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update.
For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `AdvancedVI.optimize`.

### [`ClipScale`](@id clipscale)

For the location-scale family, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
To ensure this, we provide the following projection operator:

```@docs
ClipScale
```

[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
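
Putting the interface together, the operator is supplied through the `operator` keyword of `AdvancedVI.optimize`, as in the README change above. A hypothetical end-to-end sketch, assuming that call pattern; the toy target, iteration count, and step size are illustrative only:

```julia
using AdvancedVI, ADTypes, ForwardDiff, LinearAlgebra, LogDensityProblems, Optimisers

# Hypothetical toy target: an isotropic Gaussian in two dimensions.
struct ToyNormal end
LogDensityProblems.logdensity(::ToyNormal, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyNormal) = 2
LogDensityProblems.capabilities(::Type{ToyNormal}) = LogDensityProblems.LogDensityOrder{0}()

q0 = MeanFieldGaussian(zeros(2), Diagonal(ones(2)))

q_avg, q_last, stats, _ = AdvancedVI.optimize(
    ToyNormal(),
    RepGradELBO(10),
    q0,
    10^3;
    adtype=ADTypes.AutoForwardDiff(),
    optimizer=Optimisers.Adam(1e-3),
    operator=ClipScale(),  # keep the scale diagonal above ClipScale's epsilon
)
```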
26 changes: 20 additions & 6 deletions ext/AdvancedVIBijectorsExt.jl
@@ -15,24 +15,38 @@ else
using ..Random
end

function AdvancedVI.update_variational_params!(
function AdvancedVI.apply(
op::ClipScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
opt_st,
params,
restructure,
grad,
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.dist.scale_eps
ϵ = convert(eltype(params), op.epsilon)

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.dist.scale)
@. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ)

params, _ = Optimisers.destructure(q)

return opt_st, params
return params
end

function AdvancedVI.apply(
op::ClipScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}},
params,
restructure,
)
q = restructure(params)
ϵ = convert(eltype(params), op.epsilon)

@. q.dist.scale_diag = max(q.dist.scale_diag, ϵ)

params, _ = Optimisers.destructure(q)

return params
end

function AdvancedVI.reparam_with_entropy(
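With this extension, the gradient step and the projection are decoupled: `Optimisers.update!` performs the parameter update, and `apply` projects the result back into the stable region. A hedged sketch of how the two compose inside an optimization step; `opt_st`, `params`, `grad`, `restructure`, and `q` are placeholders, and this is not verbatim from the PR's loop:

```julia
# Gradient step, then projection, then reconstruct the approximation.
opt_st, params = Optimisers.update!(opt_st, params, grad)
params = AdvancedVI.apply(ClipScale(), typeof(q), params, restructure)
q = restructure(params)
```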
65 changes: 35 additions & 30 deletions src/AdvancedVI.jl
@@ -60,34 +60,6 @@
"""
restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)

# Update for gradient descent step
"""
update_variational_params!(family_type, opt_st, params, restructure, grad)

Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`.

This is a wrapper around `Optimisers.update!` to provide some indirection.
For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings.
As with the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this function.
Instead, the return values should be used.

# Arguments
- `family_type::Type`: Type of the variational family `typeof(restructure(params))`.
- `opt_st`: Optimizer state returned by `Optimisers.setup`.
- `params`: Current set of parameters to be updated.
- `restructure`: Callable for restructuring the variational distribution from `params`.
- `grad`: Gradient to be used by the update rule of `opt_st`.

# Returns
- `opt_st`: Updated optimizer state.
- `params`: Updated parameters.
"""
function update_variational_params! end

function update_variational_params!(::Type, opt_st, params, restructure, grad)
return Optimisers.update!(opt_st, params, grad)
end

# estimators
"""
AbstractVariationalObjective
@@ -149,7 +121,7 @@
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `params`: Variational parameters to evaluate the gradient on.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
- `restructure`: Function that reconstructs the variational approximation from `params`.
- `obj_state`: Previous state of the objective.

# Returns
@@ -215,7 +187,7 @@
init(::AbstractAverager, ::Any) = nothing

"""
apply(avg, avg_st, params)
apply(avg::AbstractAverager, avg_st, params)

Apply averaging strategy `avg` on `params` given the state `avg_st`.

@@ -241,6 +213,39 @@

export NoAveraging, PolynomialAveraging

# Operators for Optimization
abstract type AbstractOperator end

"""
apply(op::AbstractOperator, family, params, restructure)

Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator.

# Arguments
- `op::AbstractOperator`: Operator operating on the parameters `params`.
- `family::Type`: Type of the variational approximation `restructure(params)`.
- `params`: Variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `params`.

# Returns
- `oped_params`: Parameters resulting from applying the operator.
"""
function apply(::AbstractOperator, ::Type, ::Any, ::Any) end

Codecov (codecov/patch) warning on src/AdvancedVI.jl#L233: added line not covered by tests.

"""
IdentityOperator()

Identity operator.
"""
struct IdentityOperator <: AbstractOperator end

apply(::IdentityOperator, ::Type, params, restructure) = params

include("optimization/clip_scale.jl")

export IdentityOperator, ClipScale

# Main optimization routine
function optimize end

export optimize
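Since `apply` dispatches on the operator type, a user-defined operator needs only a subtype and one method. A hypothetical illustration; `BoxClamp` is not part of this PR:

```julia
using AdvancedVI

# Hypothetical operator clamping every variational parameter to [lo, hi].
struct BoxClamp <: AdvancedVI.AbstractOperator
    lo::Float64
    hi::Float64
end

# Follows the documented `apply` signature; the family type argument
# enables family-specific dispatch but is unused here.
function AdvancedVI.apply(op::BoxClamp, ::Type, params, restructure)
    return clamp.(params, op.lo, op.hi)
end
```

An instance could then be passed to `AdvancedVI.optimize` via `operator=BoxClamp(-10.0, 10.0)`, in the same way as `ClipScale` above.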
76 changes: 16 additions & 60 deletions src/families/location_scale.jl
@@ -1,14 +1,6 @@

struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
ContinuousMultivariateDistribution
location::L
scale::S
dist::D
scale_eps::E
end

"""
MvLocationScale(location, scale, dist; scale_eps)
MvLocationScale(location, scale, dist)

The location scale variational family broadly represents various variational
families using `location` and `scale` variational parameters.
@@ -20,21 +12,11 @@ represented as follows:
u = rand(dist, d)
z = scale*u + location
```

`scale_eps` sets a constraint on the smallest value of `scale` to be enforced during optimization.
This is necessary to guarantee stable convergence.

# Keyword Arguments
- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `1e-4`).
"""
function MvLocationScale(
location::AbstractVector{T},
scale::AbstractMatrix{T},
dist::ContinuousUnivariateDistribution;
scale_eps::T=T(1e-4),
) where {T<:Real}
@assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior."
return MvLocationScale(location, scale, dist, scale_eps)
struct MvLocationScale{S,D<:ContinuousDistribution,L} <: ContinuousMultivariateDistribution
location::L
scale::S
dist::D
end

Functors.@functor MvLocationScale (location, scale)
@@ -44,18 +26,18 @@ Functors.@functor MvLocationScale (location, scale)
# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
# is very inefficient.
# begin
struct RestructureMeanField{S<:Diagonal,D,L,E}
model::MvLocationScale{S,D,L,E}
struct RestructureMeanField{S<:Diagonal,D,L}
model::MvLocationScale{S,D,L}
end

function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
return MvLocationScale(location, scale, re.model.dist)
end

function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
(; location, scale, dist) = q
flat = vcat(location, diag(scale))
return flat, RestructureMeanField(q)
@@ -66,7 +48,7 @@ Base.length(q::MvLocationScale) = length(q.location)

Base.size(q::MvLocationScale) = size(q.location)

Base.eltype(::Type{<:MvLocationScale{S,D,L,E}}) where {S,D,L,E} = eltype(D)
Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D)

function StatsBase.entropy(q::MvLocationScale)
(; location, scale, dist) = q
@@ -131,55 +113,29 @@ function Distributions.cov(q::MvLocationScale)
end

"""
FullRankGaussian(μ, L; scale_eps)
FullRankGaussian(μ, L)

Construct a Gaussian variational approximation with a dense covariance matrix.

# Arguments
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
"""
function FullRankGaussian(
μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=T(1e-4)
μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
end

"""
MeanFieldGaussian(μ, L; scale_eps)
MeanFieldGaussian(μ, L)

Construct a Gaussian variational approximation with a diagonal covariance matrix.

# Arguments
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
"""
function MeanFieldGaussian(
μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=T(1e-4)
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
end

function update_variational_params!(
::Type{<:MvLocationScale}, opt_st, params, restructure, grad
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.scale_eps

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.scale)
@. q.scale[diag_idx] = max(q.scale[diag_idx], ϵ)

params, _ = Optimisers.destructure(q)

return opt_st, params
function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T<:Real}
return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
end
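
With the new constructors, the families no longer carry `scale_eps`; the lower bound now lives in the `ClipScale` operator instead. A brief construction sketch under the new two-argument signatures, with a hypothetical dimension and initial values:

```julia
using AdvancedVI, LinearAlgebra

d = 5
# Diagonal (mean-field) Gaussian.
q_mf = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))
# Dense-covariance Gaussian; the scale is a lower-triangular Cholesky factor.
q_fr = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
```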