Skip to content

Commit

Permalink
add indirection for update step, add projection for LocationScale
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Jun 7, 2024
1 parent 95a83c3 commit 48607c5
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 20 deletions.
27 changes: 27 additions & 0 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,33 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di
"""
function value_and_gradient! end

# Update for gradient descent step
"""
    update_variational_params!(family_type, opt_st, params, restructure, grad)

Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`.

This is a wrapper around `Optimisers.update!` to provide some indirection.
For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings.
Same as the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this function.
Instead, the return values should be used.

# Arguments
- `family_type::Type`: Type of the variational family `typeof(restructure(params))`.
- `opt_st`: Optimizer state returned by `Optimisers.setup`.
- `params`: Current set of parameters to be updated.
- `restructure`: Callable for restructuring the variational distribution from `params`.
- `grad`: Gradient to be used by the update rule of `opt_st`.

# Returns
- `opt_st`: Updated optimizer state.
- `params`: Updated parameters.
"""
function update_variational_params! end

# Fallback for any variational family type: take a plain optimizer step and
# return whatever `Optimisers.update!` returns. `restructure` is accepted for
# interface uniformity but is not used here.
function update_variational_params!(::Type, opt_st, params, restructure, grad)
    return Optimisers.update!(opt_st, params, grad)
end

# estimators
"""
AbstractVariationalObjective
Expand Down
58 changes: 39 additions & 19 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,21 @@ represented as follows:
```
"""
struct MvLocationScale{
    S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
    # Location (shift) of the distribution.
    location ::L
    # Scale applied to the base distribution.
    scale    ::S
    # Univariate base distribution.
    dist     ::D
    # Lower bound enforced on the diagonal of `scale` by
    # `update_variational_params!` after every optimizer step.
    scale_eps::E
end

"""
    MvLocationScale(location, scale, dist; scale_eps = sqrt(eps(T)))

Convenience constructor that takes `scale_eps` as a keyword argument.

`location` and `scale` must share the element type `T`; when `scale_eps` is not
given it defaults to `sqrt(eps(T))`. All four values are forwarded unchanged to
the positional constructor.
"""
function MvLocationScale(
    location::AbstractVector{T},
    scale::AbstractMatrix{T},
    dist::ContinuousDistribution;
    scale_eps::T=sqrt(eps(T)),
) where {T<:Real}
    return MvLocationScale(location, scale, dist, scale_eps)
end

Functors.@functor MvLocationScale (location, scale)
Expand Down Expand Up @@ -57,17 +67,17 @@ Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)
function StatsBase.entropy(q::MvLocationScale)
    @unpack location, scale, dist = q
    n_dims = length(location)
    # Entropy of the base distribution once per dimension (converted to the
    # element type of `location`), plus `logdet(scale)`.
    return n_dims * convert(eltype(location), entropy(dist)) + logdet(scale)
end

function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
    @unpack location, scale, dist = q
    # Map `z` back through the location-scale transform, sum the base
    # log-densities of the result, and subtract `logdet(scale)`.
    return sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
end

function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
    @unpack location, scale, dist = q
    # Same computation as `Distributions.logpdf` for `MvLocationScale`:
    # base log-densities of `scale \ (z - location)` minus `logdet(scale)`.
    return sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
end

function Distributions.rand(q::MvLocationScale)
Expand Down Expand Up @@ -128,14 +138,11 @@ Construct a Gaussian variational approximation with a dense covariance matrix.
function FullRankGaussian(
    μ::AbstractVector{T},
    L::LinearAlgebra.AbstractTriangular{T};
    scale_eps::T = sqrt(eps(T))
) where {T <: Real}
    # Refuse initial scales whose smallest diagonal entry is below
    # `sqrt(scale_eps)`.
    @assert minimum(diag(L)) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
    # Standard normal base distribution with the same element type as `μ` and `L`.
    q_base = Normal{T}(zero(T), one(T))
    return MvLocationScale(μ, L, q_base, scale_eps)
end

"""
Expand All @@ -153,12 +160,25 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix
function MeanFieldGaussian(
    μ::AbstractVector{T},
    L::Diagonal{T};
    scale_eps::T = sqrt(eps(T)),
) where {T <: Real}
    # Refuse initial scales whose smallest diagonal entry is below
    # `sqrt(eps(eltype(L)))`. Note this threshold does not depend on `scale_eps`.
    @assert minimum(diag(L)) ≥ sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
    # Standard normal base distribution with the same element type as `μ` and `L`.
    q_base = Normal{T}(zero(T), one(T))
    return MvLocationScale(μ, L, q_base, scale_eps)
end

"""
    update_variational_params!(::Type{<:MvLocationScale}, opt_st, params, restructure, grad)

Take one optimizer step and then clamp the diagonal of the scale from below.

After `Optimisers.update!`, the distribution is rebuilt with `restructure(params)`,
every diagonal entry of its `scale` is replaced by `max(entry, scale_eps)`, and the
result is flattened again with `Optimisers.destructure`.

The method dispatches on the *type* of the variational family, because that is what
`optimize` passes (`typeof(q_init)`).

# Returns
- `opt_st`: Updated optimizer state.
- `params`: Updated (and projected) parameters.
"""
function update_variational_params!(
    ::Type{<:MvLocationScale}, opt_st, params, restructure, grad
)
    opt_st, params = Optimisers.update!(opt_st, params, grad)
    q = restructure(params)
    ϵ = q.scale_eps

    # Project the scale matrix to the set of positive definite triangular matrices
    diag_idx = diagind(q.scale)
    @. q.scale[diag_idx] = max(q.scale[diag_idx], ϵ)

    # Flatten the projected distribution back into a parameter vector.
    params, _ = Optimisers.destructure(q)

    return opt_st, params
end

# Kept so that callers passing a distribution instance instead of its type
# still reach the projected update.
function update_variational_params!(
    q::MvLocationScale, opt_st, params, restructure, grad
)
    return update_variational_params!(typeof(q), opt_st, params, restructure, grad)
end
4 changes: 3 additions & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ function optimize(
stat = merge(stat, stat′)

grad = DiffResults.gradient(grad_buf)
opt_st, params = Optimisers.update!(opt_st, params, grad)
opt_st, params = update_variational_params!(
typeof(q_init), opt_st, params, restructure, grad
)

if !isnothing(callback)
stat′ = callback(
Expand Down

0 comments on commit 48607c5

Please sign in to comment.