From 26515e411d2150a00badd398a9ee1964f3dfa147 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 7 Aug 2024 08:54:50 +0200 Subject: [PATCH 1/3] Fix Enzyme extension --- ext/NormalizingFlowsEnzymeExt.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ext/NormalizingFlowsEnzymeExt.jl b/ext/NormalizingFlowsEnzymeExt.jl index 1b59cad8..a00c864a 100644 --- a/ext/NormalizingFlowsEnzymeExt.jl +++ b/ext/NormalizingFlowsEnzymeExt.jl @@ -10,16 +10,14 @@ else using ..NormalizingFlows: ADTypes, DiffResults end -# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) function NormalizingFlows.value_and_gradient!( ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult ) where {T<:Real} - y = f(θ) - DiffResults.value!(out, y) ∇θ = DiffResults.gradient(out) fill!(∇θ, zero(T)) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) + _, y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) + DiffResults.value!(out, y) return out end -end \ No newline at end of file +end From 273427be4f44229002c45d4c2d0ff813fe987d16 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 7 Aug 2024 08:56:49 +0200 Subject: [PATCH 2/3] Enable Enzyme test --- test/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index a394d806..e700a250 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -25,7 +25,7 @@ end ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(false), - # ADTypes.AutoEnzyme(), # not working now + ADTypes.AutoEnzyme(), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) @@ -49,4 +49,4 @@ end @test all(DiffResults.gradient(out) .!= nothing) end end -end \ No newline at end of file +end From fb5c48b39f55f029b0e4cc6e54a67979554ed30d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 7 Aug 2024 09:12:40 +0200 Subject: [PATCH 3/3] Add workaround for Enzyme --- test/ad.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/ad.jl b/test/ad.jl index e700a250..aab6841b 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -41,9 +41,17 @@ end out = DiffResults.GradientResult(θ) # check grad computation for elbo + # Enzyme needs a workaround + if at isa ADTypes.AutoEnzyme + activity = Enzyme.API.runtimeActivity() + Enzyme.API.runtimeActivity!(true) + end NormalizingFlows.grad!( Random.default_rng(), at, elbo, θ, re, out, logp, sample_per_iter ) + if at isa ADTypes.AutoEnzyme + Enzyme.API.runtimeActivity!(activity) + end @test DiffResults.value(out) != nothing @test all(DiffResults.gradient(out) .!= nothing)