Commit d015eb5

Enable Enzyme (#67)
* fix enzyme to match new interface, enable enzyme tests
* fix type instability, tighten Enzyme compat
* add indirection to enforce type stability of `restructure`
* fix tests, enable Enzyme inference tests only on 1.10
1 parent 39d506b commit d015eb5

14 files changed: +129 -45 lines changed
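
Before the per-file diffs, a quick orientation: the commit wires the Enzyme backend into the aux-accepting `value_and_gradient!` entry point used by the ELBO objective. A minimal sketch of a call through that interface, mirroring the test added in test/interface/ad.jl below (the quadratic objective and the data are illustrative, not part of the package):

using LinearAlgebra
using ADTypes, DiffResults, Enzyme
using AdvancedVI

A, b, λ = randn(5, 5), randn(5), randn(5)
# Objective of the form f(x, aux); `aux` carries data that is not differentiated.
f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′)
buf = DiffResults.GradientResult(λ)              # holds both the value and the gradient
AdvancedVI.value_and_gradient!(AutoEnzyme(), f, λ, (b=b,), buf)
DiffResults.value(buf)                           # ≈ λ' * A * λ / 2 + dot(b, λ)
DiffResults.gradient(buf)                        # ≈ (A + A') * λ / 2 + b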

.github/workflows/CI.yml  (+1)

@@ -20,6 +20,7 @@ jobs:
       matrix:
         version:
           - '1.7'
+          - '1.10'
         os:
           - ubuntu-latest
           - macOS-latest

Project.toml  (+1 -1)

@@ -42,7 +42,7 @@ ChainRulesCore = "1.16"
 DiffResults = "1"
 Distributions = "0.25.87"
 DocStringExtensions = "0.8, 0.9"
-Enzyme = "0.12"
+Enzyme = "0.12.32"
 FillArrays = "1.3"
 ForwardDiff = "0.10.36"
 Functors = "0.4"

ext/AdvancedVIEnzymeExt.jl  (+31 -5)

@@ -11,13 +11,39 @@ else
     using ..AdvancedVI: ADTypes, DiffResults
 end
 
+function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params)
+    return restructure(params)::typeof(restructure.model)
+end
+
+function AdvancedVI.value_and_gradient!(
+    ::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
+)
+    Enzyme.API.runtimeActivity!(true)
+    ∇x = DiffResults.gradient(out)
+    fill!(∇x, zero(eltype(∇x)))
+    _, y = Enzyme.autodiff(
+        Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, ∇x)
+    )
+    DiffResults.value!(out, y)
+    return out
+end
+
 function AdvancedVI.value_and_gradient!(
-    ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
-) where {T<:Real}
-    ∇θ = DiffResults.gradient(out)
-    fill!(∇θ, zero(T))
+    ::ADTypes.AutoEnzyme,
+    f,
+    x::AbstractVector{<:Real},
+    aux,
+    out::DiffResults.MutableDiffResult,
+)
+    Enzyme.API.runtimeActivity!(true)
+    ∇x = DiffResults.gradient(out)
+    fill!(∇x, zero(eltype(∇x)))
     _, y = Enzyme.autodiff(
-        Enzyme.ReverseWithPrimal, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)
+        Enzyme.ReverseWithPrimal,
+        Enzyme.Const(f),
+        Enzyme.Active,
+        Enzyme.Duplicated(x, ∇x),
+        Enzyme.Const(aux),
     )
     DiffResults.value!(out, y)
     return out
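
A note on the Enzyme calling convention used above: `Duplicated(x, ∇x)` pairs the input with a shadow buffer into which the gradient is accumulated, `Const(aux)` marks the auxiliary input as not differentiated, and `ReverseWithPrimal` returns the primal value alongside the derivatives. A standalone sketch of the same pattern, independent of AdvancedVI (the toy objective is made up):

using Enzyme

g(x, aux) = sum(abs2, x) / 2 + aux.shift       # toy objective; its gradient w.r.t. x is x
x  = randn(3)
dx = zero(x)                                   # shadow buffer; receives the gradient
_, y = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal,                  # reverse mode, also return the primal
    Enzyme.Const(g),                           # the function itself carries no derivative
    Enzyme.Active,                             # activity annotation for the return value
    Enzyme.Duplicated(x, dx),                  # differentiate w.r.t. x, write into dx
    Enzyme.Const((shift=1.0,)),                # aux is held constant
)
# After the call, y == g(x, (shift=1.0,)) and dx ≈ x.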

src/AdvancedVI.jl  (+7 -8)

@@ -41,18 +41,17 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation
 function value_and_gradient! end
 
 """
-    stop_gradient(x)
+    restructure_ad_forward(adtype, restructure, params)
 
-Stop the gradient from propagating to `x` if the selected ad backend supports it.
-Otherwise, it is equivalent to `identity`.
+Apply `restructure` to `params`.
+This is an indirection for handling the type stability of `restructure`, as some AD backends require strict type stability in the AD path.
 
 # Arguments
-- `x`: Input
-
-# Returns
-- `x`: Same value as the input.
+- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `restructure`: Callable for restructuring the variational distribution from `params`.
+- `params`: Variational parameters.
 """
-function stop_gradient end
+restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)
 
 # Update for gradient descent step
 """

src/families/location_scale.jl  (+2 -2)

@@ -38,14 +38,14 @@ Functors.@functor MvLocationScale (location, scale)
 # is very inefficient.
 # begin
 struct RestructureMeanField{S<:Diagonal,D,L}
-    q::MvLocationScale{S,D,L}
+    model::MvLocationScale{S,D,L}
 end
 
 function (re::RestructureMeanField)(flat::AbstractVector)
     n_dims = div(length(flat), 2)
     location = first(flat, n_dims)
     scale = Diagonal(last(flat, n_dims))
-    return MvLocationScale(location, scale, re.q.dist, re.q.scale_eps)
+    return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
 end
 
 function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}

src/objectives/elbo/repgradelbo.jl  (+10 -3)

@@ -93,8 +93,8 @@ function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samples)
 end
 
 function estimate_repgradelbo_ad_forward(params′, aux)
-    @unpack rng, obj, problem, restructure, q_stop = aux
-    q = restructure(params′)
+    @unpack rng, obj, problem, adtype, restructure, q_stop = aux
+    q = restructure_ad_forward(adtype, restructure, params′)
     samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
     energy = estimate_energy_with_samples(problem, samples)
     elbo = energy + entropy
@@ -112,7 +112,14 @@ function estimate_gradient!(
     state,
 )
     q_stop = restructure(params)
-    aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
+    aux = (
+        rng=rng,
+        adtype=adtype,
+        obj=obj,
+        problem=prob,
+        restructure=restructure,
+        q_stop=q_stop,
+    )
     value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out)
     nelbo = DiffResults.value(out)
     stat = (elbo=-nelbo,)

test/Project.toml  (+2 -1)

@@ -26,9 +26,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 ADTypes = "0.2.1, 1"
 Bijectors = "0.13"
+DiffResults = "1.0"
 Distributions = "0.25.100"
 DistributionsAD = "0.6.45"
-Enzyme = "0.12"
+Enzyme = "0.12.32"
 FillArrays = "1.6.1"
 ForwardDiff = "0.10.36"
 Functors = "0.4.5"

test/inference/repgradelbo_distributionsad.jl  (+18 -8)

@@ -1,4 +1,19 @@
 
+AD_distributionsad = if VERSION >= v"1.10"
+    Dict(
+        :ForwarDiff => AutoForwardDiff(),
+        #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
+        :Zygote => AutoZygote(),
+        :Enzyme => AutoEnzyme(),
+    )
+else
+    Dict(
+        :ForwarDiff => AutoForwardDiff(),
+        #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
+        :Zygote => AutoZygote(),
+    )
+end
+
 @testset "inference RepGradELBO DistributionsAD" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
                                                                      [Float64, Float32],
@@ -9,12 +24,7 @@
         :RepGradELBOStickingTheLanding =>
             RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
     ),
-    (adbackname, adtype) in Dict(
-        :ForwarDiff => AutoForwardDiff(),
-        #:ReverseDiff => AutoReverseDiff(),
-        :Zygote => AutoZygote(),
-        #:Enzyme => AutoEnzyme(),
-    )
+    (adbackname, adtype) in AD_distributionsad
 
         seed = (0x38bef07cf9cc549d)
         rng = StableRNG(seed)
@@ -31,8 +41,8 @@
         # where ρ = 1 - ημ, μ is the strong convexity constant.
         contraction_rate = 1 - η * strong_convexity
 
-        μ0 = Zeros(realtype, n_dims)
-        L0 = Diagonal(Ones(realtype, n_dims))
+        μ0 = zeros(realtype, n_dims)
+        L0 = Diagonal(ones(realtype, n_dims))
         q0 = TuringDiagMvNormal(μ0, diag(L0))
 
         @testset "convergence" begin

test/inference/repgradelbo_locationscale.jl  (+16 -6)

@@ -1,4 +1,19 @@
 
+AD_locationscale = if VERSION >= v"1.10"
+    Dict(
+        :ForwarDiff => AutoForwardDiff(),
+        :ReverseDiff => AutoReverseDiff(),
+        :Zygote => AutoZygote(),
+        :Enzyme => AutoEnzyme(),
+    )
+else
+    Dict(
+        :ForwarDiff => AutoForwardDiff(),
+        :ReverseDiff => AutoReverseDiff(),
+        :Zygote => AutoZygote(),
+    )
+end
+
 @testset "inference RepGradELBO VILocationScale" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
                                                                      [Float64, Float32],
@@ -10,12 +25,7 @@
         :RepGradELBOStickingTheLanding =>
            RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
     ),
-    (adbackname, adtype) in Dict(
-        :ForwarDiff => AutoForwardDiff(),
-        :ReverseDiff => AutoReverseDiff(),
-        :Zygote => AutoZygote(),
-        #:Enzyme => AutoEnzyme(),
-    )
+    (adbackname, adtype) in AD_locationscale
 
         seed = (0x38bef07cf9cc549d)
         rng = StableRNG(seed)

test/inference/repgradelbo_locationscale_bijectors.jl  (+16 -6)

@@ -1,4 +1,19 @@
 
+AD_locationscale_bijectors = if VERSION >= v"1.10"
+    Dict(
+        :ForwarDiff => AutoForwardDiff(),
+        :ReverseDiff => AutoReverseDiff(),
+        :Zygote => AutoZygote(),
+        :Enzyme => AutoEnzyme(),
+    )
+else
+    Dict(
+        :ForwarDiff => AutoForwardDiff(),
+        :ReverseDiff => AutoReverseDiff(),
+        :Zygote => AutoZygote(),
+    )
+end
+
 @testset "inference RepGradELBO VILocationScale Bijectors" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
                                                                      [Float64, Float32],
@@ -10,12 +25,7 @@
         :RepGradELBOStickingTheLanding =>
            RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
     ),
-    (adbackname, adtype) in Dict(
-        :ForwarDiff => AutoForwardDiff(),
-        :ReverseDiff => AutoReverseDiff(),
-        #:Zygote => AutoZygote(),
-        #:Enzyme => AutoEnzyme(),
-    )
+    (adbackname, adtype) in AD_locationscale_bijectors
 
        seed = (0x38bef07cf9cc549d)
       rng = StableRNG(seed)

test/interface/ad.jl  (+19)

@@ -19,4 +19,23 @@ using Test
     @test ∇ ≈ (A + A') * λ / 2
     @test f ≈ λ' * A * λ / 2
 end
+
+@testset "$(adname) with auxiliary input" for (adname, adsymbol) in Dict(
+    :ForwardDiff => AutoForwardDiff(),
+    :ReverseDiff => AutoReverseDiff(),
+    :Zygote => AutoZygote(),
+    :Enzyme => AutoEnzyme(),
+)
+    D = 10
+    A = randn(D, D)
+    λ = randn(D)
+    b = randn(D)
+    grad_buf = DiffResults.GradientResult(λ)
+    f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′)
+    AdvancedVI.value_and_gradient!(adsymbol, f, λ, (b=b,), grad_buf)
+    ∇ = DiffResults.gradient(grad_buf)
+    f = DiffResults.value(grad_buf)
+    @test ∇ ≈ (A + A') * λ / 2 + b
+    @test f ≈ λ' * A * λ / 2 + dot(b, λ)
+end
 end
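
For reference, the values asserted above follow from the gradient identity ∇λ(λ'Aλ/2 + b'λ) = (A + A')λ/2 + b, so relative to the first test set the auxiliary term shifts the primal by dot(b, λ) and the gradient by b.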

test/interface/repgradelbo.jl  (+5 -2)

@@ -35,7 +35,10 @@ end
     @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
     @testset for ad in [
-        ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote()
+        ADTypes.AutoForwardDiff(),
+        ADTypes.AutoReverseDiff(),
+        ADTypes.AutoZygote(),
+        ADTypes.AutoEnzyme(),
     ]
         q_true = MeanFieldGaussian(
             Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true)))
@@ -44,7 +47,7 @@ end
         obj = RepGradELBO(10; entropy=StickingTheLandingEntropy())
         out = DiffResults.DiffResult(zero(eltype(params)), similar(params))
 
-        aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true)
+        aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad)
         AdvancedVI.value_and_gradient!(
             ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out
         )

test/models/normal.jl  (-2)

@@ -35,9 +35,7 @@ function normal_meanfield(rng::Random.AbstractRNG, realtype::Type)
 
     σ0 = realtype(0.3)
     μ = Fill(realtype(5), n_dims)
-    #randn(rng, realtype, n_dims)
     σ = Fill(σ0, n_dims)
-    #log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
 
     model = TestNormal(μ, Diagonal(σ .^ 2))
 
test/runtests.jl  (+1 -1)

@@ -20,7 +20,7 @@ using DistributionsAD
 using LogDensityProblems
 using Optimisers
 using ADTypes
-using Enzyme, ForwardDiff, ReverseDiff, Zygote
+using ForwardDiff, ReverseDiff, Zygote, Enzyme
 
 using AdvancedVI
 