Skip to content

Commit

Permalink
add projection for Bijectors with MvLocationScale
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Jun 7, 2024
1 parent 48607c5 commit a54c7fc
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions ext/AdvancedVIBijectorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,37 @@ module AdvancedVIBijectorsExt
if isdefined(Base, :get_extension)
using AdvancedVI
using Bijectors
using LinearAlgebra
using Optimisers
using Random
else
using ..AdvancedVI
using ..Bijectors
using ..LinearAlgebra
using ..Optimisers
using ..Random
end

function AdvancedVI.update_variational_params!(
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
opt_st,
params,
restructure,
grad
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.dist.scale_eps

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.dist.scale)
@. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ)

params, _ = Optimisers.destructure(q)

opt_st, params
end

function AdvancedVI.reparam_with_entropy(
rng ::Random.AbstractRNG,
q ::Bijectors.TransformedDistribution,
Expand Down

0 comments on commit a54c7fc

Please sign in to comment.