Skip to content

Commit

Permalink
Minor Touches for ScoreGradELBO (#99)
Browse files Browse the repository at this point in the history
* fix move log density computation of ScoreGradELBO out of the AD path

* update change the `ScoreGradELBO` objective to be VarGrad underneath

* fix remove unnecessary import

* add basic tests for interface tests of variational objectives

* tweak stepsize for inference test of ScoreGradELBO

* add docstrings to elbo objective forward ad paths

* remove `n_montecarlo` option in the inference tests and just fix it

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
3 people authored Dec 6, 2024
1 parent 227d58d commit 1dbf2ac
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 151 deletions.
7 changes: 0 additions & 7 deletions src/objectives/elbo/entropy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,3 @@ function estimate_entropy(
-logpdf(q, mc_sample)
end
end

function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end
29 changes: 27 additions & 2 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ function Base.show(io::IO, obj::RepGradELBO)
return print(io, ")")
end

function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end

function estimate_energy_with_samples(prob, samples)
return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
end
Expand Down Expand Up @@ -85,9 +92,27 @@ function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samp
return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
end

function estimate_repgradelbo_ad_forward(params′, aux)
"""
estimate_repgradelbo_ad_forward(params, aux)
AD-guaranteed forward path of the reparameterization gradient objective.
# Arguments
- `params`: Variational parameters.
- `aux`: Auxiliary information excluded from the AD path.
# Auxiliary Information
`aux` should containt the following entries:
- `rng`: Random number generator.
- `obj`: The `RepGradELBO` objective.
- `problem`: The target `LogDensityProblem`.
- `adtype`: The `ADType` used for differentiating the forward path.
- `restructure`: Callable for restructuring the varitional distribution from `params`.
- `q_stop`: A copy of `restructure(params)` with its gradient "stopped" (excluded from the AD path).
"""
function estimate_repgradelbo_ad_forward(params, aux)
(; rng, obj, problem, adtype, restructure, q_stop) = aux
q = restructure_ad_forward(adtype, restructure, params)
q = restructure_ad_forward(adtype, restructure, params)
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
energy = estimate_energy_with_samples(problem, samples)
elbo = energy + entropy
Expand Down
133 changes: 39 additions & 94 deletions src/objectives/elbo/scoregradelbo.jl
Original file line number Diff line number Diff line change
@@ -1,113 +1,63 @@

"""
ScoreGradELBO(n_samples; kwargs...)
Evidence lower-bound objective computed with score function gradients.
```math
\\begin{aligned}
\\nabla_{\\lambda} \\mathrm{ELBO}\\left(\\lambda\\right)
&\\=
\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
\\log \\pi\\left(z\\right) \\nabla_{\\lambda} \\log q_{\\lambda}(z)
\\right]
+ \\mathbb{H}\\left(q_{\\lambda}\\right),
\\end{aligned}
```
To reduce the variance of the gradient estimator, we use a baseline computed from a running average of the previous ELBO values and subtract it from the objective.
```math
\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
\\nabla_{\\lambda} \\log q_{\\lambda}(z) \\left(\\pi\\left(z\\right) - \\beta\\right)
\\right]
```
Evidence lower-bound objective computed with score function gradient with the VarGrad objective, also known as the leave-one-out control variate.
# Arguments
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
# Keyword Arguments
- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`)
- `baseline_window_size::Int`: The window size to use to compute the baseline. (Default: `10`)
- `baseline_history::Vector{Float64}`: The history of the baseline. (Default: `Float64[]`)
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the VarGrad objective.
# Requirements
- The variational approximation ``q_{\\lambda}`` implements `rand` and `logpdf`.
- `logpdf(q, x)` must be differentiable with respect to `q` by the selected AD backend.
- The target distribution and the variational approximation have the same support.
Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
"""
struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <:
AdvancedVI.AbstractVariationalObjective
entropy::EntropyEst
struct ScoreGradELBO <: AbstractVariationalObjective
n_samples::Int
baseline_window_size::Int
baseline_history::Vector{Float64}
end

function ScoreGradELBO(
n_samples::Int;
entropy::AbstractEntropyEstimator=ClosedFormEntropy(),
baseline_window_size::Int=10,
baseline_history::Vector{Float64}=Float64[],
)
return ScoreGradELBO(entropy, n_samples, baseline_window_size, baseline_history)
end

function Base.show(io::IO, obj::ScoreGradELBO)
print(io, "ScoreGradELBO(entropy=")
print(io, obj.entropy)
print(io, ", n_samples=")
print(io, "ScoreGradELBO(n_samples=")
print(io, obj.n_samples)
print(io, ", baseline_window_size=")
print(io, obj.baseline_window_size)
return print(io, ")")
end

function compute_control_variate_baseline(history, window_size)
if length(history) == 0
return 1.0
end
min_index = max(1, length(history) - window_size)
return mean(history[min_index:end])
end

function estimate_energy_with_samples(
prob, samples_stop, samples_logprob, samples_logprob_stop, baseline
)
fv = Base.Fix1(LogDensityProblems.logdensity, prob).(eachsample(samples_stop))
fv_mean = mean(fv)
score_grad = mean(@. samples_logprob * (fv - baseline))
score_grad_stop = mean(@. samples_logprob_stop * (fv - baseline))
return fv_mean + (score_grad - score_grad_stop)
end

function estimate_objective(
rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples
)
samples, entropy = reparam_with_entropy(rng, q, q, obj.n_samples, obj.entropy)
energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
return mean(energy) + entropy
samples = rand(rng, q, n_samples)
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
return mean(ℓπ - ℓq)
end

function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples)
return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
end

function estimate_scoregradelbo_ad_forward(params′, aux)
(; rng, obj, problem, adtype, restructure, q_stop) = aux
baseline = compute_control_variate_baseline(
obj.baseline_history, obj.baseline_window_size
)
q = restructure_ad_forward(adtype, restructure, params′)
samples_stop = rand(rng, q_stop, obj.n_samples)
entropy = estimate_entropy_maybe_stl(obj.entropy, samples_stop, q, q_stop)
samples_logprob = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
samples_logprob_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples_stop))
energy = estimate_energy_with_samples(
problem, samples_stop, samples_logprob, samples_logprob_stop, baseline
)
elbo = energy + entropy
return -elbo
"""
estimate_scoregradelbo_ad_forward(params, aux)
AD-guaranteed forward path of the score gradient objective.
# Arguments
- `params`: Variational parameters.
- `aux`: Auxiliary information excluded from the AD path.
# Auxiliary Information
`aux` should containt the following entries:
- `samples_stop`: Samples drawn from `q = restructure(params)` but with their gradients stopped (excluded from the AD path).
- `logprob_stop`: Log-densities of the target `LogDensityProblem` evaluated over `samples_stop`.
- `adtype`: The `ADType` used for differentiating the forward path.
- `restructure`: Callable for restructuring the varitional distribution from `params`.
"""
function estimate_scoregradelbo_ad_forward(params, aux)
(; samples_stop, logprob_stop, adtype, restructure) = aux
q = restructure_ad_forward(adtype, restructure, params)
ℓπ = logprob_stop
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
f = ℓq - ℓπ
return (mean(abs2, f) - mean(f)^2) / 2
end

function AdvancedVI.estimate_gradient!(
Expand All @@ -120,20 +70,15 @@ function AdvancedVI.estimate_gradient!(
restructure,
state,
)
q_stop = restructure(params)
aux = (
rng=rng,
adtype=adtype,
obj=obj,
problem=prob,
restructure=restructure,
q_stop=q_stop,
)
q = restructure(params)
samples = rand(rng, q, obj.n_samples)
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
AdvancedVI.value_and_gradient!(
adtype, estimate_scoregradelbo_ad_forward, params, aux, out
)
nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)
push!(obj.baseline_history, -nelbo)
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
elbo = mean(ℓπ - ℓq)
stat = (elbo=elbo,)
return out, nothing, stat
end
5 changes: 2 additions & 3 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@ end
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in Dict(:Normal => normal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding =>
RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
RepGradELBO(10; entropy=StickingTheLandingEntropy()),
),
(adbackname, adtype) in AD_repgradelbo_distributionsad

Expand Down
7 changes: 3 additions & 4 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@ else
)
end

@testset "inference ScoreGradELBO VILocationScale" begin
@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in
Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding =>
RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
RepGradELBO(10; entropy=StickingTheLandingEntropy()),
),
(adbackname, adtype) in AD_repgradelbo_locationscale

Expand Down
5 changes: 2 additions & 3 deletions test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ end
[Float64, Float32],
(modelname, modelconstr) in
Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding =>
RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
RepGradELBO(10; entropy=StickingTheLandingEntropy()),
),
(adbackname, adtype) in AD_repgradelbo_locationscale_bijectors

Expand Down
9 changes: 2 additions & 7 deletions test/inference/scoregradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@ end
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in Dict(:Normal => normal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
(adbackname, adtype) in AD_scoregradelbo_distributionsad

seed = (0x38bef07cf9cc549d)
Expand All @@ -29,7 +24,7 @@ end
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-5
η = 1e-4
opt = Optimisers.Descent(realtype(η))

# For small enough η, the error of SGD, Δλ, is bounded as
Expand Down
9 changes: 2 additions & 7 deletions test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ end
[Float64, Float32],
(modelname, modelconstr) in
Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
(adbackname, adtype) in AD_scoregradelbo_locationscale

seed = (0x38bef07cf9cc549d)
Expand All @@ -30,7 +25,7 @@ end
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-5
η = 1e-4
opt = Optimisers.Descent(realtype(η))

# For small enough η, the error of SGD, Δλ, is bounded as
Expand Down
9 changes: 2 additions & 7 deletions test/inference/scoregradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ end
[Float64, Float32],
(modelname, modelconstr) in
Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
#:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), # not supported yet.
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)),
(adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors

seed = (0x38bef07cf9cc549d)
Expand All @@ -30,7 +25,7 @@ end
(; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats

T = 1000
η = 1e-5
η = 1e-4
opt = Optimisers.Descent(realtype(η))

b = Bijectors.bijector(model)
Expand Down
Loading

0 comments on commit 1dbf2ac

Please sign in to comment.