From 238128e463a2c3c23b640cb6d3d3a616fee86a53 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 16 Nov 2024 21:10:58 -0800 Subject: [PATCH 01/21] refactor make scale projection operator its own optimization rule --- bench/benchmarks.jl | 2 +- ext/AdvancedVIBijectorsExt.jl | 23 +++++- src/AdvancedVI.jl | 13 +-- src/families/location_scale.jl | 77 +++++++----------- src/families/location_scale_low_rank.jl | 81 +++++++------------ src/optimize.jl | 2 +- test/families/location_scale.jl | 25 ++++-- test/families/location_scale_low_rank.jl | 23 ++++-- test/inference/repgradelbo_locationscale.jl | 2 +- .../repgradelbo_locationscale_bijectors.jl | 2 +- test/inference/scoregradelbo_locationscale.jl | 2 +- .../scoregradelbo_locationscale_bijectors.jl | 2 +- test/interface/optimize.jl | 2 +- 13 files changed, 127 insertions(+), 129 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 9e18bd91f..cb1b0af1e 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -40,7 +40,7 @@ begin ] max_iter = 10^4 d = LogDensityProblems.dimension(prob) - optimizer = Optimisers.Adam(T(1e-3)) + optimizer = ProjectScale(Optimisers.Adam(T(1e-3))) for (objname, obj) in [ ("RepGradELBO", RepGradELBO(10)), diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index f66e0ea21..9ae8bf9da 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -16,6 +16,7 @@ else end function AdvancedVI.update_variational_params!( + proj::ProjectScale, ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}}, opt_st, params, @@ -24,9 +25,8 @@ function AdvancedVI.update_variational_params!( ) opt_st, params = Optimisers.update!(opt_st, params, grad) q = restructure(params) - ϵ = q.dist.scale_eps + ϵ = proj.scale_eps - # Project the scale matrix to the set of positive definite triangular matrices diag_idx = diagind(q.dist.scale) @. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ) @@ -35,6 +35,25 @@ function AdvancedVI.update_variational_params!( return opt_st, params end +function AdvancedVI.update_variational_params!( + proj::ProjectScale, + ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}}, + opt_st, + params, + restructure, + grad, +) + opt_st, params = Optimisers.update!(opt_st, params, grad) + q = restructure(params) + ϵ = proj.scale_eps + + @. q.dist.scale_diag = max(q.dist.scale_diag, ϵ) + + params, _ = Optimisers.destructure(q) + + return opt_st, params +end + function AdvancedVI.reparam_with_entropy( rng::Random.AbstractRNG, q::Bijectors.TransformedDistribution, diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 1d0c4f502..4f4441a6e 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -63,9 +63,9 @@ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restruct # Update for gradient descent step """ - update_variational_params!(family_type, opt_st, params, restructure, grad) + update_variational_params!(rule, family_type, opt_st, params, restructure, grad) -Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`. +Update variational distribution according to the update rule in the optimizer state `opt_st`, the optimizer given by `rule`, and the variational family type `family_type`. This is a wrapper around `Optimisers.update!` to provide some indirection. For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings. 
@@ -73,6 +73,7 @@ Same as the default behavior of `Optimisers.update!`, `params` and `opt_st` may Instead, the return values should be used. # Arguments +- `rule`: Optimization rule. - `family_type::Type`: Type of the variational family `typeof(restructure(params))`. - `opt_st`: Optimizer state returned by `Optimisers.setup`. - `params`: Current set of parameters to be updated. @@ -83,9 +84,9 @@ Instead, the return values should be used. - `opt_st`: Updated optimizer state. - `params`: Updated parameters. """ -function update_variational_params! end - -function update_variational_params!(::Type, opt_st, params, restructure, grad) +function update_variational_params!( + ::Optimisers.AbstractRule, family_type, opt_st, params, restructure, grad +) return Optimisers.update!(opt_st, params, grad) end @@ -186,7 +187,7 @@ include("objectives/elbo/repgradelbo.jl") include("objectives/elbo/scoregradelbo.jl") # Variational Families -export MvLocationScale, MeanFieldGaussian, FullRankGaussian +export MvLocationScale, MeanFieldGaussian, FullRankGaussian, ProjectScale include("families/location_scale.jl") diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 22af4b4a9..8f7887e8c 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -1,14 +1,6 @@ -struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <: - ContinuousMultivariateDistribution - location::L - scale::S - dist::D - scale_eps::E -end - """ - MvLocationScale(location, scale, dist; scale_eps) + MvLocationScale(location, scale, dist) The location scale variational family broadly represents various variational families using `location` and `scale` variational parameters. @@ -20,21 +12,11 @@ represented as follows: u = rand(dist, d) z = scale*u + location ``` - -`scale_eps` sets a constraint on the smallest value of `scale` to be enforced during optimization. -This is necessary to guarantee stable convergence. - -# Keyword Arguments -- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `1e-4`). """ -function MvLocationScale( - location::AbstractVector{T}, - scale::AbstractMatrix{T}, - dist::ContinuousUnivariateDistribution; - scale_eps::T=T(1e-4), -) where {T<:Real} - @assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior." - return MvLocationScale(location, scale, dist, scale_eps) +struct MvLocationScale{S,D<:ContinuousDistribution,L} <: ContinuousMultivariateDistribution + location::L + scale::S + dist::D end Functors.@functor MvLocationScale (location, scale) @@ -44,18 +26,18 @@ Functors.@functor MvLocationScale (location, scale) # `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD # is very inefficient. 
# begin -struct RestructureMeanField{S<:Diagonal,D,L,E} - model::MvLocationScale{S,D,L,E} +struct RestructureMeanField{S<:Diagonal,D,L} + model::MvLocationScale{S,D,L} end function (re::RestructureMeanField)(flat::AbstractVector) n_dims = div(length(flat), 2) location = first(flat, n_dims) scale = Diagonal(last(flat, n_dims)) - return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps) + return MvLocationScale(location, scale, re.model.dist) end -function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E} +function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L} @unpack location, scale, dist = q flat = vcat(location, diag(scale)) return flat, RestructureMeanField(q) @@ -66,7 +48,7 @@ Base.length(q::MvLocationScale) = length(q.location) Base.size(q::MvLocationScale) = size(q.location) -Base.eltype(::Type{<:MvLocationScale{S,D,L,E}}) where {S,D,L,E} = eltype(D) +Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D) function StatsBase.entropy(q::MvLocationScale) @unpack location, scale, dist = q @@ -131,49 +113,52 @@ function Distributions.cov(q::MvLocationScale) end """ - FullRankGaussian(μ, L; scale_eps) + FullRankGaussian(μ, L) Construct a Gaussian variational approximation with a dense covariance matrix. # Arguments - `μ::AbstractVector{T}`: Mean of the Gaussian. - `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian. - -# Keyword Arguments -- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`). """ function FullRankGaussian( - μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=T(1e-4) + μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T} ) where {T<:Real} - q_base = Normal{T}(zero(T), one(T)) - return MvLocationScale(μ, L, q_base, scale_eps) + return MvLocationScale(μ, L, Normal{T}(zero(T), one(T))) end """ - MeanFieldGaussian(μ, L; scale_eps) + MeanFieldGaussian(μ, L) Construct a Gaussian variational approximation with a diagonal covariance matrix. # Arguments - `μ::AbstractVector{T}`: Mean of the Gaussian. - `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian. - -# Keyword Arguments -- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`). 
""" -function MeanFieldGaussian( - μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=T(1e-4) -) where {T<:Real} - q_base = Normal{T}(zero(T), one(T)) - return MvLocationScale(μ, L, q_base, scale_eps) +function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T<:Real} + return MvLocationScale(μ, L, Normal{T}(zero(T), one(T))) +end + +struct ProjectScale{Rule<:Optimisers.AbstractRule,F<:Real} <: Optimisers.AbstractRule + rule::Rule + scale_eps::F +end + +function ProjectScale(rule, scale_eps::Real=1e-5) + return ProjectScale{typeof(rule),typeof(scale_eps)}(rule, scale_eps) end +Optimisers.setup(proj::ProjectScale, x) = Optimisers.setup(proj.rule, x) + +Optimisers.init(proj::ProjectScale, x) = Optimisers.init(proj.rule, x) + function update_variational_params!( - ::Type{<:MvLocationScale}, opt_st, params, restructure, grad + proj::ProjectScale, ::Type{<:MvLocationScale}, opt_st, params, restructure, grad ) opt_st, params = Optimisers.update!(opt_st, params, grad) q = restructure(params) - ϵ = q.scale_eps + ϵ = convert(eltype(params), proj.scale_eps) # Project the scale matrix to the set of positive definite triangular matrices diag_idx = diagind(q.scale) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index e2044142f..c5ffc96ec 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -1,16 +1,6 @@ -struct MvLocationScaleLowRank{ - L,SD<:AbstractVector,SF<:AbstractMatrix,D<:ContinuousDistribution,E<:Real -} <: ContinuousMultivariateDistribution - location::L - scale_diag::SD - scale_factors::SF - dist::D - scale_eps::E -end - """ - MvLocationLowRankScale(location, scale_diag, scale_factors, dist; scale_eps) + MvLocationLowRankScale(location, scale_diag, scale_factors, dist) Variational family with a covariance in the form of a diagonal matrix plus a squared low-rank matrix. The rank is given by `size(scale_factors, 2)`. @@ -24,23 +14,14 @@ represented as follows: u_factors = rand(dist, r) z = scale_diag.*u_diag + scale_factors*u_factors + location ``` - -`scale_eps` sets a constraint on the smallest value of `scale_diag` to be enforced during optimization. -This is necessary to guarantee stable convergence. - -# Keyword Arguments -- `scale_eps`: Lower bound constraint for the values of scale_diag. (default: `sqrt(eps(T))`). """ -function MvLocationScaleLowRank( - location::AbstractVector{T}, - scale_diag::AbstractVector{T}, - scale_factors::AbstractMatrix{T}, - dist::ContinuousUnivariateDistribution; - scale_eps::T=T(1e-4), -) where {T<:Real} - @assert minimum(scale_diag) ≥ scale_eps "Initial scale is too small (smallest diagonal scale value is $(minimum(scale_diag)). This might result in unstable optimization behavior." 
- @assert size(scale_factors, 1) == length(scale_diag) - return MvLocationScaleLowRank(location, scale_diag, scale_factors, dist, scale_eps) +struct MvLocationScaleLowRank{ + L,SD<:AbstractVector,SF<:AbstractMatrix,D<:ContinuousDistribution +} <: ContinuousMultivariateDistribution + location::L + scale_diag::SD + scale_factors::SF + dist::D end Functors.@functor MvLocationScaleLowRank (location, scale_diag, scale_factors) @@ -49,7 +30,7 @@ Base.length(q::MvLocationScaleLowRank) = length(q.location) Base.size(q::MvLocationScaleLowRank) = size(q.location) -Base.eltype(::Type{<:MvLocationScaleLowRank{L,SD,SF,D,E}}) where {L,SD,SF,D,E} = eltype(L) +Base.eltype(::Type{<:MvLocationScaleLowRank{L,SD,SF,D}}) where {L,SD,SF,D} = eltype(L) function StatsBase.entropy(q::MvLocationScaleLowRank) @unpack location, scale_diag, scale_factors, dist = q @@ -95,8 +76,8 @@ function Distributions.rand(q::MvLocationScaleLowRank) end function Distributions.rand( - rng::AbstractRNG, q::MvLocationScaleLowRank{S,D,L}, num_samples::Int -) where {S,D,L} + rng::AbstractRNG, q::MvLocationScaleLowRank, num_samples::Int +) @unpack location, scale_diag, scale_factors, dist = q n_dims = length(location) n_factors = size(scale_factors, 2) @@ -140,23 +121,8 @@ function Distributions.cov(q::MvLocationScaleLowRank) return σ2 * (Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors') end -function update_variational_params!( - ::Type{<:MvLocationScaleLowRank}, opt_st, params, restructure, grad -) - opt_st, params = Optimisers.update!(opt_st, params, grad) - q = restructure(params) - ϵ = q.scale_eps - - # Clip diagonal to guarantee positive definite covariance - @. q.scale_diag = max(q.scale_diag, ϵ) - - params, _ = Optimisers.destructure(q) - - return opt_st, params -end - """ - LowRankGaussian(μ, D, U; scale_eps) + LowRankGaussian(μ, D, U) Construct a Gaussian variational approximation with a diagonal plus low-rank covariance matrix. @@ -164,13 +130,22 @@ Construct a Gaussian variational approximation with a diagonal plus low-rank cov - `μ::AbstractVector{T}`: Mean of the Gaussian. - `D::Vector{T}`: Diagonal of the scale. - `U::Matrix{T}`: Low-rank factors of the scale, where `size(U,2)` is the rank. - -# Keyword Arguments -- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`). """ -function LowRankGaussian( - μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=T(1e-4) -) where {T<:Real} +function LowRankGaussian(μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}) where {T<:Real} q_base = Normal{T}(zero(T), one(T)) - return MvLocationScaleLowRank(μ, D, U, q_base; scale_eps) + return MvLocationScaleLowRank(μ, D, U, q_base) +end + +function update_variational_params!( + proj::ProjectScale, ::Type{<:MvLocationScaleLowRank}, opt_st, params, restructure, grad +) + opt_st, params = Optimisers.update!(opt_st, params, grad) + q = restructure(params) + ϵ = convert(eltype(params), proj.scale_eps) + + @. 
q.scale_diag = max(q.scale_diag, ϵ) + + params, _ = Optimisers.destructure(q) + + return opt_st, params end diff --git a/src/optimize.jl b/src/optimize.jl index 8ef9db764..18e4d29fd 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -88,7 +88,7 @@ function optimize( grad = DiffResults.gradient(grad_buf) opt_st, params = update_variational_params!( - typeof(q_init), opt_st, params, restructure, grad + optimizer, typeof(q_init), opt_st, params, restructure, grad ) avg_st = apply(averager, avg_st, params) diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl index c112352e9..92d15f6ab 100644 --- a/test/families/location_scale.jl +++ b/test/families/location_scale.jl @@ -154,24 +154,33 @@ ϵ = sqrt(realtype(0.5)) q = if covtype == :fullrank L = LowerTriangular(Matrix{realtype}(I, d, d)) - FullRankGaussian(μ, L; scale_eps=ϵ) + FullRankGaussian(μ, L) elseif covtype == :meanfield L = Diagonal(ones(realtype, d)) - MeanFieldGaussian(μ, L; scale_eps=ϵ) + MeanFieldGaussian(μ, L) end - q_trans = if isnothing(bijector) + q = if isnothing(bijector) q else Bijectors.TransformedDistribution(q, identity) end - g = deepcopy(q) + q_cpy = deepcopy(q) λ, re = Optimisers.destructure(q) - grad, _ = Optimisers.destructure(g) - opt_st = Optimisers.setup(Descent(one(realtype)), λ) - _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + grad, _ = Optimisers.destructure(q_cpy) + opt = Descent(one(realtype)) + proj = ProjectScale(opt, ϵ) + opt_st = Optimisers.setup(proj, λ) + _, λ′ = AdvancedVI.update_variational_params!( + proj, typeof(q), opt_st, λ, re, grad + ) q′ = re(λ′) - @test all(var(q′) .≥ ϵ^2) + + if isnothing(bijector) + @test all(var(q′) .≥ ϵ^2) + else + @test all(var(q′.dist) .≥ ϵ^2) + end end end diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 2accb971b..7e1c49b66 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -158,21 +158,30 @@ D = ones(realtype, d) U = randn(realtype, d, n_rank) q = MvLocationScaleLowRank( - μ, D, U, Normal{realtype}(zero(realtype), one(realtype)); scale_eps=ϵ + μ, D, U, Normal{realtype}(zero(realtype), one(realtype)) ) - q_trans = if isnothing(bijector) + q = if isnothing(bijector) q else Bijectors.TransformedDistribution(q, bijector) end - g = deepcopy(q) + q_cpy = deepcopy(q) λ, re = Optimisers.destructure(q) - grad, _ = Optimisers.destructure(g) - opt_st = Optimisers.setup(Descent(one(realtype)), λ) - _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + grad, _ = Optimisers.destructure(q_cpy) + opt = Descent(one(realtype)) + proj = ProjectScale(opt, ϵ) + opt_st = Optimisers.setup(proj, λ) + _, λ′ = AdvancedVI.update_variational_params!( + proj, typeof(q), opt_st, λ, re, grad + ) q′ = re(λ′) - @test all(var(q′) .≥ ϵ^2) + + if isnothing(bijector) + @test all(var(q′) .≥ ϵ^2) + else + @test all(var(q′.dist) .≥ ϵ^2) + end end end end diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index d1f0d7e41..db89f3c42 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -36,7 +36,7 @@ end T = 1000 η = 1e-3 - opt = Optimisers.Descent(realtype(η)) + opt = ProjectScale(Optimisers.Descent(realtype(η))) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl 
index e2a69d62e..338f56151 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -36,7 +36,7 @@ end T = 1000 η = 1e-3 - opt = Optimisers.Descent(realtype(η)) + opt = ProjectScale(Optimisers.Descent(realtype(η))) b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl index 753999dee..60623d6fc 100644 --- a/test/inference/scoregradelbo_locationscale.jl +++ b/test/inference/scoregradelbo_locationscale.jl @@ -36,7 +36,7 @@ end T = 1000 η = 1e-5 - opt = Optimisers.Descent(realtype(η)) + opt = ProjectScale(Optimisers.Descent(realtype(η))) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl index bee8234ab..7d638ff3c 100644 --- a/test/inference/scoregradelbo_locationscale_bijectors.jl +++ b/test/inference/scoregradelbo_locationscale_bijectors.jl @@ -34,7 +34,7 @@ end T = 1000 η = 1e-5 - opt = Optimisers.Descent(realtype(η)) + opt = ProjectScale(Optimisers.Descent(realtype(η))) b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index 268098b5b..d7294ccaf 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -15,7 +15,7 @@ using Test obj = RepGradELBO(10) adtype = AutoForwardDiff() - optimizer = Optimisers.Adam(1e-2) + optimizer = ProjectScale(Optimisers.Adam(1e-2)) averager = PolynomialAveraging() @testset "default_rng" begin From 03338d6333c18a8e40ab87fa6040c773d2e15cfc Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 16 Nov 2024 21:58:47 -0800 Subject: [PATCH 02/21] add docs for `ProjectScale` --- docs/src/elbo/repgradelbo.md | 7 ++++--- docs/src/examples.md | 5 ++++- docs/src/families.md | 10 ++++++++++ src/families/location_scale.jl | 9 +++++++++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/elbo/repgradelbo.md index 10af6a521..cb089a383 100644 --- a/docs/src/elbo/repgradelbo.md +++ b/docs/src/elbo/repgradelbo.md @@ -219,7 +219,7 @@ _, _, stats_cfe, _ = AdvancedVI.optimize( max_iter; show_progress = false, adtype = AutoForwardDiff(), - optimizer = Optimisers.Adam(3e-3), + optimizer = ProjectScale(Optimisers.Adam(3e-3)), callback = callback, ); @@ -230,7 +230,7 @@ _, _, stats_stl, _ = AdvancedVI.optimize( max_iter; show_progress = false, adtype = AutoForwardDiff(), - optimizer = Optimisers.Adam(3e-3), + optimizer = ProjectScale(Optimisers.Adam(3e-3)), callback = callback, ); @@ -265,6 +265,7 @@ Furthermore, in a lot of cases, a low-accuracy solution may be sufficient. [^RWD2017]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. [^KMG2024]: Kim, K., Ma, Y., & Gardner, J. (2024). Linear Convergence of Black-Box Variational Inference: Should We Stick the Landing?. In International Conference on Artificial Intelligence and Statistics (pp. 235-243). PMLR. 
+ ## Advanced Usage There are two major ways to customize the behavior of `RepGradELBO` @@ -317,7 +318,7 @@ _, _, stats_qmc, _ = AdvancedVI.optimize( max_iter; show_progress = false, adtype = AutoForwardDiff(), - optimizer = Optimisers.Adam(3e-3), + optimizer = ProjectScale(Optimisers.Adam(3e-3)), callback = callback, ); diff --git a/docs/src/examples.md b/docs/src/examples.md index 15b8907a8..078c004ad 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -118,11 +118,14 @@ q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize( n_max_iter; show_progress=false, adtype=AutoForwardDiff(), - optimizer=Optimisers.Adam(1e-3), + optimizer=ProjectScale(Optimisers.Adam(1e-3)), ); nothing ``` +`ProjectScale` is a wrapper around an optimization rule such that the variational approximation stays within a stable region of the variational family. +For more information see [this section](@ref projectscale). + `q_avg_trans` is the final output of the optimization procedure. If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` is be the output of the averaging strategy, while `q_trans` is the last iterate. diff --git a/docs/src/families.md b/docs/src/families.md index e270acad8..113eb016c 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -56,6 +56,16 @@ FullRankGaussian MeanFieldGaussian ``` +### [Scale Projection Operator](@id projectscale) +For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020]. +To ensure this, we provide the following wrapper around optimization rule: + +```@docs +ProjectScale +``` + +[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*. + ### Gaussian Variational Families ```julia diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 8f7887e8c..4193c0873 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -140,6 +140,15 @@ function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T<:Real return MvLocationScale(μ, L, Normal{T}(zero(T), one(T))) end +""" + ProjectScale(rule, scale_eps) + +Compose an optimization `rule` with a projection, where the projection ensures that a `LocationScale` or `LocationScaleLowRank` has a scale with eigenvalues larger than `scale_eps`. + +# Arguments +- `rule::Optimisers.AbstractRule`: Optimization rule to compose with the projection. +- `scale_eps::Real`: Lower bound on the eigenvalues of the scale matrix of the projection. 
+""" struct ProjectScale{Rule<:Optimisers.AbstractRule,F<:Real} <: Optimisers.AbstractRule rule::Rule scale_eps::F From 233cffac299fa72059fd806afc948847309be38b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 16 Nov 2024 21:58:53 -0800 Subject: [PATCH 03/21] refactor change of type parameter order for `LocationScaleLowRank` --- src/families/location_scale_low_rank.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index c5ffc96ec..0e3ed4c6f 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -16,7 +16,7 @@ represented as follows: ``` """ struct MvLocationScaleLowRank{ - L,SD<:AbstractVector,SF<:AbstractMatrix,D<:ContinuousDistribution + D<:ContinuousDistribution,L,SD<:AbstractVector,SF<:AbstractMatrix } <: ContinuousMultivariateDistribution location::L scale_diag::SD @@ -30,7 +30,7 @@ Base.length(q::MvLocationScaleLowRank) = length(q.location) Base.size(q::MvLocationScaleLowRank) = size(q.location) -Base.eltype(::Type{<:MvLocationScaleLowRank{L,SD,SF,D}}) where {L,SD,SF,D} = eltype(L) +Base.eltype(::Type{<:MvLocationScaleLowRank{D,L,SD,SF}}) where {D,L,SD,SF} = eltype(L) function StatsBase.entropy(q::MvLocationScaleLowRank) @unpack location, scale_diag, scale_factors, dist = q From 960d77d804e79be5f1456897237b02e07ea6beb2 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 17 Nov 2024 01:06:26 -0500 Subject: [PATCH 04/21] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/families.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/families.md b/docs/src/families.md index 113eb016c..68ff4f678 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -65,7 +65,6 @@ ProjectScale ``` [^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*. - ### Gaussian Variational Families ```julia From db42115b0c580e3790f9efc081ecb5dc9b6de6af Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 17 Nov 2024 01:06:32 -0500 Subject: [PATCH 05/21] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/families.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/families.md b/docs/src/families.md index 68ff4f678..1c6f64728 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -57,6 +57,7 @@ MeanFieldGaussian ``` ### [Scale Projection Operator](@id projectscale) + For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020]. To ensure this, we provide the following wrapper around optimization rule: From 6dd0fd6569db7c132c69aa9d5f8d27b642b83844 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 17 Nov 2024 01:06:40 -0500 Subject: [PATCH 06/21] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/elbo/repgradelbo.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/elbo/repgradelbo.md index cb089a383..ff404a517 100644 --- a/docs/src/elbo/repgradelbo.md +++ b/docs/src/elbo/repgradelbo.md @@ -265,7 +265,6 @@ Furthermore, in a lot of cases, a low-accuracy solution may be sufficient. [^RWD2017]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). 
Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. [^KMG2024]: Kim, K., Ma, Y., & Gardner, J. (2024). Linear Convergence of Black-Box Variational Inference: Should We Stick the Landing?. In International Conference on Artificial Intelligence and Statistics (pp. 235-243). PMLR. - ## Advanced Usage There are two major ways to customize the behavior of `RepGradELBO` From 074218a2cf08f5acd9e71753fd27122ac12821be Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 16 Nov 2024 22:08:11 -0800 Subject: [PATCH 07/21] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9bd000c75..bc1ef6bb7 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ q_avg, _, stats, _ = AdvancedVI.optimize( q_transformed, max_iter; adtype=ADTypes.AutoForwardDiff(), - optimizer=Optimisers.Adam(1e-3), + optimizer=ProjectScale(Optimisers.Adam(1e-3)), ) # Evaluate final ELBO with 10^3 Monte Carlo samples From a3ce1d16f4832e661e8ed7fb243fbee38035634e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 9 Dec 2024 00:02:49 -0500 Subject: [PATCH 08/21] fix formatting --- src/families/location_scale_low_rank.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index ba81862ba..e707f13d2 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -75,9 +75,7 @@ function Distributions.rand(q::MvLocationScaleLowRank) return scale_diag .* u_diag + scale_factors * u_fact + location end -function Distributions.rand( - rng::AbstractRNG, q::MvLocationScaleLowRank, num_samples::Int -) +function Distributions.rand(rng::AbstractRNG, q::MvLocationScaleLowRank, num_samples::Int) (; location, scale_diag, scale_factors, dist) = q n_dims = length(location) n_factors = size(scale_factors, 2) From ee36164fd8d729cdc5de944e3f3168fa6792320c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 9 Dec 2024 20:52:20 -0500 Subject: [PATCH 09/21] fix outdated type parameters in `LocationScale` --- src/families/location_scale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index b9bc99faa..256eeb859 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -37,7 +37,7 @@ function (re::RestructureMeanField)(flat::AbstractVector) return MvLocationScale(location, scale, re.model.dist) end -function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E} +function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L} (; location, scale, dist) = q flat = vcat(location, diag(scale)) return flat, RestructureMeanField(q) From cd35e4e27a3019b466a53528c2c07870e5455c58 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Dec 2024 06:31:03 -0500 Subject: [PATCH 10/21] rename averaging function --- src/optimization/averaging.jl | 4 ++-- src/optimize.jl | 6 ++---- test/interface/averaging.jl | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/optimization/averaging.jl b/src/optimization/averaging.jl index 19c375d81..ae69974a3 100644 --- a/src/optimization/averaging.jl +++ b/src/optimization/averaging.jl @@ -8,7 +8,7 @@ struct NoAveraging <: AbstractAverager end init(::NoAveraging, x) = x -apply(::NoAveraging, state, x) = x +average(::NoAveraging, state, x) = x value(::NoAveraging, state) = state 
@@ -41,7 +41,7 @@ PolynomialAveraging() = PolynomialAveraging(8) init(::PolynomialAveraging, x) = (x, 1) -function apply(avg::PolynomialAveraging, state, x::AbstractVector{T}) where {T} +function average(avg::PolynomialAveraging, state, x::AbstractVector{T}) where {T} eta = T(avg.eta) x_bar, t = state diff --git a/src/optimize.jl b/src/optimize.jl index 18e4d29fd..10106a3be 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -87,10 +87,8 @@ function optimize( stat = merge(stat, stat′) grad = DiffResults.gradient(grad_buf) - opt_st, params = update_variational_params!( - optimizer, typeof(q_init), opt_st, params, restructure, grad - ) - avg_st = apply(averager, avg_st, params) + opt_st, params = Optimisers.update!(opt_st, params, grad) + avg_st = average(averager, avg_st, params) if !isnothing(callback) averaged_params = value(averager, avg_st) diff --git a/test/interface/averaging.jl b/test/interface/averaging.jl index e7a23e5ef..3d9e17aa9 100644 --- a/test/interface/averaging.jl +++ b/test/interface/averaging.jl @@ -6,7 +6,7 @@ function simulate_sequence_average(realtype::Type{<:Real}, avg::AdvancedVI.Abstr xs_it = eachcol(xs) st = AdvancedVI.init(avg, first(xs_it)) for x in xs_it - st = AdvancedVI.apply(avg, st, x) + st = AdvancedVI.average(avg, st, x) end return AdvancedVI.value(avg, st), xs end From f40df752594a703298bedb075a44052783ddb42d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Dec 2024 09:17:42 -0500 Subject: [PATCH 11/21] fix projection/proximal operator interface --- README.md | 3 +- ext/AdvancedVIBijectorsExt.jl | 23 +++---- src/AdvancedVI.jl | 61 +++++++++---------- src/families/location_scale.jl | 38 ------------ src/families/location_scale_low_rank.jl | 14 ----- src/optimization/clip_scale.jl | 39 ++++++++++++ src/optimize.jl | 2 + test/families/location_scale.jl | 41 ------------- test/families/location_scale_low_rank.jl | 38 ------------ test/inference/repgradelbo_distributionsad.jl | 2 +- test/inference/repgradelbo_locationscale.jl | 5 +- .../repgradelbo_locationscale_bijectors.jl | 5 +- .../scoregradelbo_distributionsad.jl | 2 +- test/inference/scoregradelbo_locationscale.jl | 5 +- .../scoregradelbo_locationscale_bijectors.jl | 5 +- test/interface/optimize.jl | 2 +- test/interface/rules.jl | 4 +- test/runtests.jl | 1 + 18 files changed, 105 insertions(+), 185 deletions(-) create mode 100644 src/optimization/clip_scale.jl diff --git a/README.md b/README.md index 15c561209..e5d1ea7f2 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,8 @@ q_avg, _, stats, _ = AdvancedVI.optimize( q_transformed, max_iter; adtype=ADTypes.AutoForwardDiff(), - optimizer=ProjectScale(Optimisers.Adam(1e-3)), + optimizer=Optimisers.Adam(1e-3), + operator=ClipScale(), ) # Evaluate final ELBO with 10^3 Monte Carlo samples diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 9ae8bf9da..5417147f4 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -15,43 +15,38 @@ else using ..Random end -function AdvancedVI.update_variational_params!( - proj::ProjectScale, +function AdvancedVI.operate( + op::ClipScale, ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}}, - opt_st, params, restructure, - grad, ) - opt_st, params = Optimisers.update!(opt_st, params, grad) q = restructure(params) - ϵ = proj.scale_eps + ϵ = convert(eltype(params), op.epsilon) + # Project the scale matrix to the set of positive definite triangular matrices diag_idx = diagind(q.dist.scale) @. 
q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ) params, _ = Optimisers.destructure(q) - return opt_st, params + return params end -function AdvancedVI.update_variational_params!( - proj::ProjectScale, +function AdvancedVI.operate( + op::ClipScale, ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}}, - opt_st, params, restructure, - grad, ) - opt_st, params = Optimisers.update!(opt_st, params, grad) q = restructure(params) - ϵ = proj.scale_eps + ϵ = convert(eltype(params), op.epsilon) @. q.dist.scale_diag = max(q.dist.scale_diag, ϵ) params, _ = Optimisers.destructure(q) - return opt_st, params + return params end function AdvancedVI.reparam_with_entropy( diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d9f12b7e2..acadbdf96 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -60,35 +60,6 @@ This is an indirection for handling the type stability of `restructure`, as some """ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params) -# Update for gradient descent step -""" - update_variational_params!(rule, family_type, opt_st, params, restructure, grad) - -Update variational distribution according to the update rule in the optimizer state `opt_st`, the optimizer given by `rule`, and the variational family type `family_type`. - -This is a wrapper around `Optimisers.update!` to provide some indirection. -For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings. -Same as the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this functino. -Instead, the return values should be used. - -# Arguments -- `rule`: Optimization rule. -- `family_type::Type`: Type of the variational family `typeof(restructure(params))`. -- `opt_st`: Optimizer state returned by `Optimisers.setup`. -- `params`: Current set of parameters to be updated. -- `restructure`: Callable for restructuring the varitional distribution from `params`. -- `grad`: Gradient to be used by the update rule of `opt_st`. - -# Returns -- `opt_st`: Updated optimizer state. -- `params`: Updated parameters. -""" -function update_variational_params!( - ::Optimisers.AbstractRule, family_type, opt_st, params, restructure, grad -) - return Optimisers.update!(opt_st, params, grad) -end - # estimators """ AbstractVariationalObjective @@ -150,7 +121,7 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ - `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. - `params`: Variational parameters to evaluate the gradient on. -- `restructure`: Function that reconstructs the variational approximation from `λ`. +- `restructure`: Function that reconstructs the variational approximation from `params`. - `obj_state`: Previous state of the objective. 
# Returns @@ -186,7 +157,7 @@ include("objectives/elbo/repgradelbo.jl") include("objectives/elbo/scoregradelbo.jl") # Variational Families -export MvLocationScale, MeanFieldGaussian, FullRankGaussian, ProjectScale +export MvLocationScale, MeanFieldGaussian, FullRankGaussian include("families/location_scale.jl") @@ -242,6 +213,34 @@ include("optimization/averaging.jl") export NoAveraging, PolynomialAveraging +# Operators for Optimization +abstract type AbstractOperator end + +""" + operate(op, family, params, restructure) + +Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator. + +# Arguments +- `op::AbstractOperator`: Operator operating on the parameters `params`. +- `family::Type`: Type of the variational approximation `restructure(params)`. +- `params`: Variational parameters. +- `restructure`: Function that reconstructs the variational approximation from `params`. + +# Returns +- `oped_params`: Parameters resulting from applying the operator. +""" +function operate end + +struct IdentityOperator <: AbstractOperator end + +operate(::IdentityOperator, family, params, restructure) = params + +include("optimization/clip_scale.jl") + +export IdentityOperator, ClipScale + +# Main optimization routine function optimize end export optimize diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 256eeb859..01ae057c5 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -139,41 +139,3 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T<:Real} return MvLocationScale(μ, L, Normal{T}(zero(T), one(T))) end - -""" - ProjectScale(rule, scale_eps) - -Compose an optimization `rule` with a projection, where the projection ensures that a `LocationScale` or `LocationScaleLowRank` has a scale with eigenvalues larger than `scale_eps`. - -# Arguments -- `rule::Optimisers.AbstractRule`: Optimization rule to compose with the projection. -- `scale_eps::Real`: Lower bound on the eigenvalues of the scale matrix of the projection. -""" -struct ProjectScale{Rule<:Optimisers.AbstractRule,F<:Real} <: Optimisers.AbstractRule - rule::Rule - scale_eps::F -end - -function ProjectScale(rule, scale_eps::Real=1e-5) - return ProjectScale{typeof(rule),typeof(scale_eps)}(rule, scale_eps) -end - -Optimisers.setup(proj::ProjectScale, x) = Optimisers.setup(proj.rule, x) - -Optimisers.init(proj::ProjectScale, x) = Optimisers.init(proj.rule, x) - -function update_variational_params!( - proj::ProjectScale, ::Type{<:MvLocationScale}, opt_st, params, restructure, grad -) - opt_st, params = Optimisers.update!(opt_st, params, grad) - q = restructure(params) - ϵ = convert(eltype(params), proj.scale_eps) - - # Project the scale matrix to the set of positive definite triangular matrices - diag_idx = diagind(q.scale) - @. 
q.scale[diag_idx] = max(q.scale[diag_idx], ϵ) - - params, _ = Optimisers.destructure(q) - - return opt_st, params -end diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index e707f13d2..f3e444a30 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -133,17 +133,3 @@ function LowRankGaussian(μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}) wher q_base = Normal{T}(zero(T), one(T)) return MvLocationScaleLowRank(μ, D, U, q_base) end - -function update_variational_params!( - proj::ProjectScale, ::Type{<:MvLocationScaleLowRank}, opt_st, params, restructure, grad -) - opt_st, params = Optimisers.update!(opt_st, params, grad) - q = restructure(params) - ϵ = convert(eltype(params), proj.scale_eps) - - @. q.scale_diag = max(q.scale_diag, ϵ) - - params, _ = Optimisers.destructure(q) - - return opt_st, params -end diff --git a/src/optimization/clip_scale.jl b/src/optimization/clip_scale.jl new file mode 100644 index 000000000..c2369df2e --- /dev/null +++ b/src/optimization/clip_scale.jl @@ -0,0 +1,39 @@ + +""" + ClipScale(ϵ = 1e-5) + +Apply a projection ensuring that an `MvLocationScale` or `MvLocationScaleLowRank` has a scale with eigenvalues larger than `ϵ`. +`ClipScale` also supports by operating on `MvLocationScale` and `MvLocationScaleLowRank` wrapped by a `Bijectors.TransformedDistribution` object. +""" +Optimisers.@def struct ClipScale <: AbstractOperator + epsilon = 1e-5 +end + +function operate(op::ClipScale, family::Type, params, restructure) + return error("`ClipScale` is not defined for the variational family of type $(family).") +end + +function operate(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure) + q = restructure(params) + ϵ = convert(eltype(params), op.epsilon) + + # Project the scale matrix to the set of positive definite triangular matrices + diag_idx = diagind(q.scale) + @. q.scale[diag_idx] = max(q.scale[diag_idx], ϵ) + + params, _ = Optimisers.destructure(q) + + return params +end + +function operate(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, params, restructure) + q = restructure(params) + ϵ = convert(eltype(params), op.epsilon) + + # Project the scale matrix to the set of positive definite triangular matrices + @. 
q.scale_diag = max(q.scale_diag, ϵ) + + params, _ = Optimisers.destructure(q) + + return params +end diff --git a/src/optimize.jl b/src/optimize.jl index 10106a3be..272c80d22 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -57,6 +57,7 @@ function optimize( adtype::ADTypes.AbstractADType, optimizer::Optimisers.AbstractRule=Optimisers.Adam(), averager::AbstractAverager=NoAveraging(), + operator::AbstractOperator=IdentityOperator(), show_progress::Bool=true, state_init::NamedTuple=NamedTuple(), callback=nothing, @@ -88,6 +89,7 @@ function optimize( grad = DiffResults.gradient(grad_buf) opt_st, params = Optimisers.update!(opt_st, params, grad) + params = operate(operator, typeof(q_init), params, restructure) avg_st = average(averager, avg_st, params) if !isnothing(callback) diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl index e9c0bda95..1ad8a5cf5 100644 --- a/test/families/location_scale.jl +++ b/test/families/location_scale.jl @@ -143,47 +143,6 @@ end end - @testset "scale positive definite projection" begin - @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in - [:meanfield, :fullrank], - realtype in [Float32, Float64], - bijector in [nothing, :identity] - - d = 5 - μ = zeros(realtype, d) - ϵ = sqrt(realtype(0.5)) - q = if covtype == :fullrank - L = LowerTriangular(Matrix{realtype}(I, d, d)) - FullRankGaussian(μ, L) - elseif covtype == :meanfield - L = Diagonal(ones(realtype, d)) - MeanFieldGaussian(μ, L) - end - q = if isnothing(bijector) - q - else - Bijectors.TransformedDistribution(q, identity) - end - q_cpy = deepcopy(q) - - λ, re = Optimisers.destructure(q) - grad, _ = Optimisers.destructure(q_cpy) - opt = Descent(one(realtype)) - proj = ProjectScale(opt, ϵ) - opt_st = Optimisers.setup(proj, λ) - _, λ′ = AdvancedVI.update_variational_params!( - proj, typeof(q), opt_st, λ, re, grad - ) - q′ = re(λ′) - - if isnothing(bijector) - @test all(var(q′) .≥ ϵ^2) - else - @test all(var(q′.dist) .≥ ϵ^2) - end - end - end - @testset "Diagonal destructure" begin n_dims = 10 μ = zeros(n_dims) diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 7e1c49b66..56bdd02ea 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -146,42 +146,4 @@ end end end - - @testset "diagonal positive definite projection" begin - @testset "$(realtype) $(bijector)" for realtype in [Float32, Float64], - bijector in [nothing, :identity] - - n_rank = 2 - d = 5 - μ = zeros(realtype, d) - ϵ = sqrt(realtype(0.5)) - D = ones(realtype, d) - U = randn(realtype, d, n_rank) - q = MvLocationScaleLowRank( - μ, D, U, Normal{realtype}(zero(realtype), one(realtype)) - ) - q = if isnothing(bijector) - q - else - Bijectors.TransformedDistribution(q, bijector) - end - q_cpy = deepcopy(q) - - λ, re = Optimisers.destructure(q) - grad, _ = Optimisers.destructure(q_cpy) - opt = Descent(one(realtype)) - proj = ProjectScale(opt, ϵ) - opt_st = Optimisers.setup(proj, λ) - _, λ′ = AdvancedVI.update_variational_params!( - proj, typeof(q), opt_st, λ, re, grad - ) - q′ = re(λ′) - - if isnothing(bijector) - @test all(var(q′) .≥ ϵ^2) - else - @test all(var(q′.dist) .≥ ϵ^2) - end - end - end end diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 286011ad6..c0e3f6db9 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -29,7 +29,7 @@ end T = 1000 η = 1e-3 - opt = 
Optimisers.Descent(realtype(η)) + opt = Optimisers.Descent(η) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index a3aee8d03..6467681f4 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -30,7 +30,7 @@ end T = 1000 η = 1e-3 - opt = ProjectScale(Optimisers.Descent(realtype(η))) + opt = Optimisers.Descent(η) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), @@ -53,6 +53,7 @@ end q0, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) @@ -75,6 +76,7 @@ end q0, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) @@ -89,6 +91,7 @@ end q0, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 9594016ce..7adba6398 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -30,7 +30,7 @@ end T = 1000 η = 1e-3 - opt = ProjectScale(Optimisers.Descent(realtype(η))) + opt = Optimisers.Descent(η) b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -59,6 +59,7 @@ end q0_z, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) @@ -81,6 +82,7 @@ end q0_z, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) @@ -95,6 +97,7 @@ end q0_z, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl index c7aa9a44c..cea10b621 100644 --- a/test/inference/scoregradelbo_distributionsad.jl +++ b/test/inference/scoregradelbo_distributionsad.jl @@ -25,7 +25,7 @@ end T = 1000 η = 1e-4 - opt = Optimisers.Descent(realtype(η)) + opt = Optimisers.Descent(η) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl index 4b822e8cd..7a3964d1a 100644 --- a/test/inference/scoregradelbo_locationscale.jl +++ b/test/inference/scoregradelbo_locationscale.jl @@ -26,7 +26,7 @@ end T = 1000 η = 1e-4 - opt = ProjectScale(Optimisers.Descent(realtype(η))) + opt = Optimisers.Descent(η) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), @@ -49,6 +49,7 @@ end q0, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) @@ -71,6 +72,7 @@ end q0, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) @@ -85,6 +87,7 @@ end q0, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl index 8fa5cac0b..ed88ca086 100644 --- a/test/inference/scoregradelbo_locationscale_bijectors.jl +++ b/test/inference/scoregradelbo_locationscale_bijectors.jl @@ -26,7 +26,7 @@ end T = 1000 η = 1e-4 - opt = ProjectScale(Optimisers.Descent(realtype(η))) + opt = Optimisers.Descent(η) b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -55,6 +55,7 @@ end q0_z, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) @@ -77,6 +78,7 @@ end q0_z, T; optimizer=opt, + operator=ClipScale(), 
show_progress=PROGRESS, adtype=adtype, ) @@ -91,6 +93,7 @@ end q0_z, T; optimizer=opt, + operator=ClipScale(), show_progress=PROGRESS, adtype=adtype, ) diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index c51e39cdc..70944004c 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -15,7 +15,7 @@ using Test obj = RepGradELBO(10) adtype = AutoForwardDiff() - optimizer = ProjectScale(Optimisers.Adam(1e-2)) + optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() @testset "default_rng" begin diff --git a/test/interface/rules.jl b/test/interface/rules.jl index 39bee1d52..ed261f3d7 100644 --- a/test/interface/rules.jl +++ b/test/interface/rules.jl @@ -1,6 +1,8 @@ @testset "rules" begin - @testset "$(rule) $(realtype)" for rule in [DoWG(), DoG(), COCOB()], + @testset "$(rule) $(realtype)" for rule in [ + DoWG(), DoG(), COCOB(), DoWG(1e-5), DoG(1e-5), COCOB(100.0) + ], realtype in [Float32, Float64] T = 10^4 diff --git a/test/runtests.jl b/test/runtests.jl index 6a0bf7af2..2130d0e7b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -49,6 +49,7 @@ if TEST_GROUP == "All" || TEST_GROUP == "Interface" include("interface/rules.jl") include("interface/averaging.jl") include("interface/scoregradelbo.jl") + include("interface/clip_scale.jl") end if TEST_GROUP == "All" || TEST_GROUP == "Interface" || TEST_GROUP == "Enzyme" From 97f64e13864479ab48957ad4e16b9e7399947130 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Dec 2024 09:24:48 -0500 Subject: [PATCH 12/21] update documentation --- docs/src/elbo/repgradelbo.md | 9 ++++++--- docs/src/examples.md | 7 ++++--- docs/src/families.md | 10 ---------- docs/src/optimization.md | 16 ++++++++++++++++ src/AdvancedVI.jl | 5 +++++ src/optimization/clip_scale.jl | 2 +- src/optimize.jl | 1 + 7 files changed, 33 insertions(+), 17 deletions(-) diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/elbo/repgradelbo.md index 59d5dc962..ccc69a97f 100644 --- a/docs/src/elbo/repgradelbo.md +++ b/docs/src/elbo/repgradelbo.md @@ -218,7 +218,8 @@ _, _, stats_cfe, _ = AdvancedVI.optimize( max_iter; show_progress = false, adtype = AutoForwardDiff(), - optimizer = ProjectScale(Optimisers.Adam(3e-3)), + optimizer = Optimisers.Adam(3e-3), + operator = ClipScale(), callback = callback, ); @@ -229,7 +230,8 @@ _, _, stats_stl, _ = AdvancedVI.optimize( max_iter; show_progress = false, adtype = AutoForwardDiff(), - optimizer = ProjectScale(Optimisers.Adam(3e-3)), + optimizer = Optimisers.Adam(3e-3), + operator = ClipScale(), callback = callback, ); @@ -316,7 +318,8 @@ _, _, stats_qmc, _ = AdvancedVI.optimize( max_iter; show_progress = false, adtype = AutoForwardDiff(), - optimizer = ProjectScale(Optimisers.Adam(3e-3)), + optimizer = Optimisers.Adam(3e-3), + operator = ClipScale(), callback = callback, ); diff --git a/docs/src/examples.md b/docs/src/examples.md index 9ecd3d26d..9c0292b81 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -117,13 +117,14 @@ q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize( n_max_iter; show_progress=false, adtype=AutoForwardDiff(), - optimizer=ProjectScale(Optimisers.Adam(1e-3)), + optimizer=Optimisers.Adam(1e-3), + operator=ClipScale(), ); nothing ``` -`ProjectScale` is a wrapper around an optimization rule such that the variational approximation stays within a stable region of the variational family. -For more information see [this section](@ref projectscale). 
+`ClipScale` is a projection operator, which ensures that the variational approximation stays within a stable region of the variational family. +For more information see [this section](@ref clipscale). `q_avg_trans` is the final output of the optimization procedure. If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` is be the output of the averaging strategy, while `q_trans` is the last iterate. diff --git a/docs/src/families.md b/docs/src/families.md index 1c6f64728..e270acad8 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -56,16 +56,6 @@ FullRankGaussian MeanFieldGaussian ``` -### [Scale Projection Operator](@id projectscale) - -For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020]. -To ensure this, we provide the following wrapper around optimization rule: - -```@docs -ProjectScale -``` - -[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*. ### Gaussian Variational Families ```julia diff --git a/docs/src/optimization.md b/docs/src/optimization.md index 05fe035d4..af2b99fde 100644 --- a/docs/src/optimization.md +++ b/docs/src/optimization.md @@ -26,3 +26,19 @@ PolynomialAveraging [^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973. [^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). Dowg unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769. [^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). Dog is sgd's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR. + +## Operators + +Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update. +For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `optimize`. + +### `ClipScale` (@id clipscale) + +For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020]. +To ensure this, we provide the following projection operator: + +```@docs +ClipScale +``` + +[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*. diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index acadbdf96..6255392bc 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -232,6 +232,11 @@ Apply operator `op` on the variational parameters `params`. For instance, `op` c """ function operate end +""" + IdentityOperator() + +Identity operator. +""" struct IdentityOperator <: AbstractOperator end operate(::IdentityOperator, family, params, restructure) = params diff --git a/src/optimization/clip_scale.jl b/src/optimization/clip_scale.jl index c2369df2e..a51bc2928 100644 --- a/src/optimization/clip_scale.jl +++ b/src/optimization/clip_scale.jl @@ -2,7 +2,7 @@ """ ClipScale(ϵ = 1e-5) -Apply a projection ensuring that an `MvLocationScale` or `MvLocationScaleLowRank` has a scale with eigenvalues larger than `ϵ`. 
+Projection operator ensuring that an `MvLocationScale` or `MvLocationScaleLowRank` has a scale with eigenvalues larger than `ϵ`. `ClipScale` also supports by operating on `MvLocationScale` and `MvLocationScaleLowRank` wrapped by a `Bijectors.TransformedDistribution` object. """ Optimisers.@def struct ClipScale <: AbstractOperator diff --git a/src/optimize.jl b/src/optimize.jl index 272c80d22..93e2afff4 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -17,6 +17,7 @@ This requires the variational approximation to be marked as a functor through `F - `adtype::ADtypes.AbstractADType`: Automatic differentiation backend. - `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.) - `averager::AbstractAverager` : Parameter averaging strategy. (Default: `NoAveraging()`) +- `operator::AbstractOperator` : Operator applied to the parameters after each optimization step. (Default: `IdentityOperator()`) - `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.) - `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.) - `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.) From 9f1a54947d6ac477a86909c7173eee946991bee6 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Dec 2024 09:30:39 -0500 Subject: [PATCH 13/21] fix formatting --- docs/src/optimization.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/src/optimization.md b/docs/src/optimization.md index af2b99fde..5acace13e 100644 --- a/docs/src/optimization.md +++ b/docs/src/optimization.md @@ -26,8 +26,7 @@ PolynomialAveraging [^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973. [^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). Dowg unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769. [^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). Dog is sgd's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR. - -## Operators +## Operators Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update. For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `optimize`. 
From ebe06378864dc74ebf6a7bbe6ad7614af09389b7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Dec 2024 09:38:25 -0500 Subject: [PATCH 14/21] fix benchmark --- bench/benchmarks.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index cb1b0af1e..5e5f11766 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -40,7 +40,8 @@ begin ] max_iter = 10^4 d = LogDensityProblems.dimension(prob) - optimizer = ProjectScale(Optimisers.Adam(T(1e-3))) + optimizer = Optimisers.Adam(T(1e-3)) + operator = ClipScale() for (objname, obj) in [ ("RepGradELBO", RepGradELBO(10)), @@ -73,6 +74,7 @@ begin $max_iter; adtype=$adtype, optimizer=$optimizer, + operator=$operator, show_progress=false, ) end From dcf21db42dd71904b8662a1f2703ded7e6c1514c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Dec 2024 09:39:14 -0500 Subject: [PATCH 15/21] add missing test file --- test/interface/clip_scale.jl | 67 ++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 test/interface/clip_scale.jl diff --git a/test/interface/clip_scale.jl b/test/interface/clip_scale.jl new file mode 100644 index 000000000..29181cbf0 --- /dev/null +++ b/test/interface/clip_scale.jl @@ -0,0 +1,67 @@ + +@testset "interface ClipScale" begin + @testset "MvLocationScale" begin + @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in + [:meanfield, :fullrank], + realtype in [Float32, Float64], + bijector in [nothing, :identity] + + d = 5 + μ = zeros(realtype, d) + ϵ = sqrt(realtype(0.5)) + q = if covtype == :fullrank + L = LowerTriangular(Matrix{realtype}(I, d, d)) + FullRankGaussian(μ, L) + elseif covtype == :meanfield + L = Diagonal(ones(realtype, d)) + MeanFieldGaussian(μ, L) + end + q = if isnothing(bijector) + q + else + Bijectors.TransformedDistribution(q, identity) + end + + params, re = Optimisers.destructure(q) + params′ = AdvancedVI.operate(ClipScale(ϵ), typeof(q), params, re) + q′ = re(params′) + + if isnothing(bijector) + @test all(var(q′) .≥ ϵ^2) + else + @test all(var(q′.dist) .≥ ϵ^2) + end + end + end + + @testset "MvLocationScaleLowRank" begin + @testset "$(realtype) $(bijector)" for realtype in [Float32, Float64], + bijector in [nothing, :identity] + + n_rank = 2 + d = 5 + μ = zeros(realtype, d) + ϵ = sqrt(realtype(0.5)) + D = ones(realtype, d) + U = randn(realtype, d, n_rank) + q = MvLocationScaleLowRank( + μ, D, U, Normal{realtype}(zero(realtype), one(realtype)) + ) + q = if isnothing(bijector) + q + else + Bijectors.TransformedDistribution(q, bijector) + end + + params, re = Optimisers.destructure(q) + params′ = AdvancedVI.operate(ClipScale(ϵ), typeof(q), params, re) + q′ = re(params′) + + if isnothing(bijector) + @test all(var(q′) .≥ ϵ^2) + else + @test all(var(q′.dist) .≥ ϵ^2) + end + end + end +end From 78683176bc9944a791b060b97c354796786605ac Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Dec 2024 09:47:37 -0500 Subject: [PATCH 16/21] fix documentation --- docs/src/optimization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/optimization.md b/docs/src/optimization.md index 5acace13e..4a0a5bcf2 100644 --- a/docs/src/optimization.md +++ b/docs/src/optimization.md @@ -31,7 +31,7 @@ PolynomialAveraging Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update. 
For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `optimize`. -### `ClipScale` (@id clipscale) +### [`ClipScale`](@id clipscale) For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020]. To ensure this, we provide the following projection operator: From 04db344df890797184b01fefa03c6b66cbae20fc Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Dec 2024 09:48:27 -0500 Subject: [PATCH 17/21] fix documentation --- src/AdvancedVI.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 6255392bc..c1516117e 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -187,7 +187,7 @@ Initialize the state of the averaging strategy `avg` with the initial parameters init(::AbstractAverager, ::Any) = nothing """ - apply(avg, avg_st, params) + average(avg, avg_st, params) Apply averaging strategy `avg` on `params` given the state `avg_st`. @@ -196,7 +196,7 @@ Apply averaging strategy `avg` on `params` given the state `avg_st`. - `avg_st`: Previous state of the averaging strategy. - `params`: Initial variational parameters. """ -function apply(::AbstractAverager, ::Any, ::Any) end +function average(::AbstractAverager, ::Any, ::Any) end """ value(avg, avg_st) @@ -230,7 +230,7 @@ Apply operator `op` on the variational parameters `params`. For instance, `op` c # Returns - `oped_params`: Parameters resulting from applying the operator. """ -function operate end +function operate(::AbstractOperator, ::Type, ::Any, ::Any) end """ IdentityOperator() From f731bdc703e7ec268b73c5f6262b78949ef6308c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 24 Dec 2024 10:03:04 -0500 Subject: [PATCH 18/21] fix ambiguous specialization error for `operate` --- src/AdvancedVI.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index c1516117e..85dad9fd2 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -239,7 +239,7 @@ Identity operator. """ struct IdentityOperator <: AbstractOperator end -operate(::IdentityOperator, family, params, restructure) = params +operate(::IdentityOperator, ::Type, params, restructure) = params include("optimization/clip_scale.jl") From 86e1ab361fa16b03a9c8b874573d5b3d433c6e5e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Dec 2024 02:35:36 -0500 Subject: [PATCH 19/21] update documentation Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- docs/src/optimization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/optimization.md b/docs/src/optimization.md index 4a0a5bcf2..422c9db86 100644 --- a/docs/src/optimization.md +++ b/docs/src/optimization.md @@ -29,7 +29,7 @@ PolynomialAveraging ## Operators Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update. -For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `optimize`. +For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `AdvancedVI.optimize`. 
### [`ClipScale`](@id clipscale) From 1b3b7346db8f67b0eab9edb331a5f8e6a0cb5684 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Dec 2024 16:05:59 -0500 Subject: [PATCH 20/21] refactor `average` and `operate` to specializations of `apply` --- docs/src/optimization.md | 1 + ext/AdvancedVIBijectorsExt.jl | 4 ++-- src/AdvancedVI.jl | 10 +++++----- src/optimization/averaging.jl | 4 ++-- src/optimization/clip_scale.jl | 6 +++--- src/optimize.jl | 4 ++-- test/interface/averaging.jl | 2 +- test/interface/clip_scale.jl | 4 ++-- 8 files changed, 18 insertions(+), 17 deletions(-) diff --git a/docs/src/optimization.md b/docs/src/optimization.md index 4a0a5bcf2..985fd8e90 100644 --- a/docs/src/optimization.md +++ b/docs/src/optimization.md @@ -26,6 +26,7 @@ PolynomialAveraging [^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973. [^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). Dowg unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769. [^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). Dog is sgd's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR. + ## Operators Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update. diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 5417147f4..1f414b6cd 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -15,7 +15,7 @@ else using ..Random end -function AdvancedVI.operate( +function AdvancedVI.apply( op::ClipScale, ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}}, params, @@ -33,7 +33,7 @@ function AdvancedVI.operate( return params end -function AdvancedVI.operate( +function AdvancedVI.apply( op::ClipScale, ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}}, params, diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 85dad9fd2..31285d302 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -187,7 +187,7 @@ Initialize the state of the averaging strategy `avg` with the initial parameters init(::AbstractAverager, ::Any) = nothing """ - average(avg, avg_st, params) + apply(avg::AbstractAverager, avg_st, params) Apply averaging strategy `avg` on `params` given the state `avg_st`. @@ -196,7 +196,7 @@ Apply averaging strategy `avg` on `params` given the state `avg_st`. - `avg_st`: Previous state of the averaging strategy. - `params`: Initial variational parameters. """ -function average(::AbstractAverager, ::Any, ::Any) end +function apply(::AbstractAverager, ::Any, ::Any) end """ value(avg, avg_st) @@ -217,7 +217,7 @@ export NoAveraging, PolynomialAveraging abstract type AbstractOperator end """ - operate(op, family, params, restructure) + apply(op::AbstractOperator, family, params, restructure) Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator. @@ -230,7 +230,7 @@ Apply operator `op` on the variational parameters `params`. For instance, `op` c # Returns - `oped_params`: Parameters resulting from applying the operator. 
""" -function operate(::AbstractOperator, ::Type, ::Any, ::Any) end +function apply(::AbstractOperator, ::Type, ::Any, ::Any) end """ IdentityOperator() @@ -239,7 +239,7 @@ Identity operator. """ struct IdentityOperator <: AbstractOperator end -operate(::IdentityOperator, ::Type, params, restructure) = params +apply(::IdentityOperator, ::Type, params, restructure) = params include("optimization/clip_scale.jl") diff --git a/src/optimization/averaging.jl b/src/optimization/averaging.jl index ae69974a3..19c375d81 100644 --- a/src/optimization/averaging.jl +++ b/src/optimization/averaging.jl @@ -8,7 +8,7 @@ struct NoAveraging <: AbstractAverager end init(::NoAveraging, x) = x -average(::NoAveraging, state, x) = x +apply(::NoAveraging, state, x) = x value(::NoAveraging, state) = state @@ -41,7 +41,7 @@ PolynomialAveraging() = PolynomialAveraging(8) init(::PolynomialAveraging, x) = (x, 1) -function average(avg::PolynomialAveraging, state, x::AbstractVector{T}) where {T} +function apply(avg::PolynomialAveraging, state, x::AbstractVector{T}) where {T} eta = T(avg.eta) x_bar, t = state diff --git a/src/optimization/clip_scale.jl b/src/optimization/clip_scale.jl index a51bc2928..68aac072a 100644 --- a/src/optimization/clip_scale.jl +++ b/src/optimization/clip_scale.jl @@ -9,11 +9,11 @@ Optimisers.@def struct ClipScale <: AbstractOperator epsilon = 1e-5 end -function operate(op::ClipScale, family::Type, params, restructure) +function apply(::ClipScale, family::Type, params, restructure) return error("`ClipScale` is not defined for the variational family of type $(family).") end -function operate(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure) +function apply(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure) q = restructure(params) ϵ = convert(eltype(params), op.epsilon) @@ -26,7 +26,7 @@ function operate(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure) return params end -function operate(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, params, restructure) +function apply(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, params, restructure) q = restructure(params) ϵ = convert(eltype(params), op.epsilon) diff --git a/src/optimize.jl b/src/optimize.jl index 93e2afff4..5c65ad777 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -90,8 +90,8 @@ function optimize( grad = DiffResults.gradient(grad_buf) opt_st, params = Optimisers.update!(opt_st, params, grad) - params = operate(operator, typeof(q_init), params, restructure) - avg_st = average(averager, avg_st, params) + params = apply(operator, typeof(q_init), params, restructure) + avg_st = apply(averager, avg_st, params) if !isnothing(callback) averaged_params = value(averager, avg_st) diff --git a/test/interface/averaging.jl b/test/interface/averaging.jl index 3d9e17aa9..e7a23e5ef 100644 --- a/test/interface/averaging.jl +++ b/test/interface/averaging.jl @@ -6,7 +6,7 @@ function simulate_sequence_average(realtype::Type{<:Real}, avg::AdvancedVI.Abstr xs_it = eachcol(xs) st = AdvancedVI.init(avg, first(xs_it)) for x in xs_it - st = AdvancedVI.average(avg, st, x) + st = AdvancedVI.apply(avg, st, x) end return AdvancedVI.value(avg, st), xs end diff --git a/test/interface/clip_scale.jl b/test/interface/clip_scale.jl index 29181cbf0..d9a6330ce 100644 --- a/test/interface/clip_scale.jl +++ b/test/interface/clip_scale.jl @@ -23,7 +23,7 @@ end params, re = Optimisers.destructure(q) - params′ = AdvancedVI.operate(ClipScale(ϵ), typeof(q), params, re) + params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), params, re) q′ = 
re(params′) if isnothing(bijector) @@ -54,7 +54,7 @@ end params, re = Optimisers.destructure(q) - params′ = AdvancedVI.operate(ClipScale(ϵ), typeof(q), params, re) + params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), params, re) q′ = re(params′) if isnothing(bijector) From 635ea4e771ce87ba02454ba0820daf2aac05d603 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Sun, 29 Dec 2024 22:53:45 +0000 Subject: [PATCH 21/21] Update docs/src/optimization.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/optimization.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/optimization.md b/docs/src/optimization.md index b70f30e52..422c9db86 100644 --- a/docs/src/optimization.md +++ b/docs/src/optimization.md @@ -26,7 +26,6 @@ PolynomialAveraging [^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973. [^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). Dowg unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769. [^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). Dog is sgd's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR. - ## Operators Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update.
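A closing note on the `operate`/`average` to `apply` unification in the last two patches: third-party operators now extend the same `apply(op, family_type, params, restructure)` method that `ClipScale` and `IdentityOperator` specialize. The sketch below defines a hypothetical `ClampParams` operator purely to illustrate that dispatch surface; it is not part of this series, and it deliberately acts on the flattened parameter vector so that no assumptions about a family's internal fields are needed.

```julia
using AdvancedVI

# Hypothetical operator (illustration only): clamp every flattened variational
# parameter to the interval [-bound, bound] after each gradient step.
struct ClampParams{T<:Real} <: AdvancedVI.AbstractOperator
    bound::T
end

# Specialization of the unified `apply` interface introduced by this series.
# The family type argument is ignored because the clamp acts directly on the
# raw parameter vector, so no restructure/destructure round trip is needed.
function AdvancedVI.apply(op::ClampParams, ::Type, params, restructure)
    b = convert(eltype(params), op.bound)  # match the parameter element type, as `ClipScale` does
    return clamp.(params, -b, b)
end
```

Such an operator would be supplied in the same way as `ClipScale`, e.g. `operator=ClampParams(100.0)` in the call to `AdvancedVI.optimize`.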