Skip to content

Commit

Permalink
add missing test file
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Dec 24, 2024
1 parent ebe0637 commit dcf21db
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions test/interface/clip_scale.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@

@testset "interface ClipScale" begin
@testset "MvLocationScale" begin
@testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in
[:meanfield, :fullrank],
realtype in [Float32, Float64],
bijector in [nothing, :identity]

d = 5
μ = zeros(realtype, d)
ϵ = sqrt(realtype(0.5))
q = if covtype == :fullrank
L = LowerTriangular(Matrix{realtype}(I, d, d))
FullRankGaussian(μ, L)
elseif covtype == :meanfield
L = Diagonal(ones(realtype, d))
MeanFieldGaussian(μ, L)
end
q = if isnothing(bijector)
q
else
Bijectors.TransformedDistribution(q, identity)
end

params, re = Optimisers.destructure(q)
params′ = AdvancedVI.operate(ClipScale(ϵ), typeof(q), params, re)
q′ = re(params′)

if isnothing(bijector)
@test all(var(q′) .≥ ϵ^2)
else
@test all(var(q′.dist) .≥ ϵ^2)
end
end
end

@testset "MvLocationScaleLowRank" begin
@testset "$(realtype) $(bijector)" for realtype in [Float32, Float64],
bijector in [nothing, :identity]

n_rank = 2
d = 5
μ = zeros(realtype, d)
ϵ = sqrt(realtype(0.5))
D = ones(realtype, d)
U = randn(realtype, d, n_rank)
q = MvLocationScaleLowRank(
μ, D, U, Normal{realtype}(zero(realtype), one(realtype))
)
q = if isnothing(bijector)
q
else
Bijectors.TransformedDistribution(q, bijector)
end

params, re = Optimisers.destructure(q)
params′ = AdvancedVI.operate(ClipScale(ϵ), typeof(q), params, re)
q′ = re(params′)

if isnothing(bijector)
@test all(var(q′) .≥ ϵ^2)
else
@test all(var(q′.dist) .≥ ϵ^2)
end
end
end
end

0 comments on commit dcf21db

Please sign in to comment.