From ee4bd5433d2d1b9b80bbf2e81619a2e731cc031c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 22:47:18 -0700 Subject: [PATCH 01/16] migrate to DifferentiationInterface --- Project.toml | 23 +++------- ext/AdvancedVIEnzymeExt.jl | 40 +---------------- ext/AdvancedVIForwardDiffExt.jl | 42 ----------------- ext/AdvancedVIReverseDiffExt.jl | 36 --------------- ext/AdvancedVITapirExt.jl | 37 --------------- ext/AdvancedVIZygoteExt.jl | 36 --------------- src/AdvancedVI.jl | 23 ++-------- src/objectives/elbo/repgradelbo.jl | 10 ++--- src/optimize.jl | 5 +-- test/Project.toml | 6 ++- test/inference/repgradelbo_distributionsad.jl | 6 +-- test/inference/repgradelbo_locationscale.jl | 6 +-- .../repgradelbo_locationscale_bijectors.jl | 6 +-- test/interface/ad.jl | 45 ------------------- test/interface/repgradelbo.jl | 11 +++-- test/runtests.jl | 11 ++--- 16 files changed, 42 insertions(+), 301 deletions(-) delete mode 100644 ext/AdvancedVIForwardDiffExt.jl delete mode 100644 ext/AdvancedVIReverseDiffExt.jl delete mode 100644 ext/AdvancedVITapirExt.jl delete mode 100644 ext/AdvancedVIZygoteExt.jl delete mode 100644 test/interface/ad.jl diff --git a/Project.toml b/Project.toml index 6322bfa7..66b5cb08 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,7 @@ version = "0.3.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -24,52 +24,43 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] AdvancedVIBijectorsExt = "Bijectors" AdvancedVIEnzymeExt = "Enzyme" -AdvancedVIForwardDiffExt = "ForwardDiff" -AdvancedVIReverseDiffExt = "ReverseDiff" -AdvancedVITapirExt = "Tapir" -AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1, 0.2, 1" Accessors = "0.1" Bijectors = "0.13" ChainRulesCore = "1.16" -DiffResults = "1" +DifferentiationInterface = "0.6" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" Enzyme = "0.13" FillArrays = "1.3" -ForwardDiff = "0.10.36" +ForwardDiff = "0.10" Functors = "0.4" LinearAlgebra = "1" LogDensityProblems = "2" +Mooncake = "0.4" Optimisers = "0.2.16, 0.3" ProgressMeter = "1.6" Random = "1" Requires = "1.0" -ReverseDiff = "1.15.1" +ReverseDiff = "1" SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" -Tapir = "0.2" -Zygote = "0.6.63" +Zygote = "0.6" julia = "1.7" [extras] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Pkg", "Test"] diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 3b68d531..45eaa329 100644 --- 
a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -1,53 +1,17 @@ - module AdvancedVIEnzymeExt if isdefined(Base, :get_extension) using Enzyme using AdvancedVI - using AdvancedVI: ADTypes, DiffResults + using AdvancedVI: ADTypes else using ..Enzyme using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults + using ..AdvancedVI: ADTypes end function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params) return restructure(params)::typeof(restructure.model) end -function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult -) - ∇x = DiffResults.gradient(out) - fill!(∇x, zero(eltype(∇x))) - _, y = Enzyme.autodiff( - Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), - Enzyme.Const(f), - Enzyme.Active, - Enzyme.Duplicated(x, ∇x), - ) - DiffResults.value!(out, y) - return out -end - -function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoEnzyme, - f, - x::AbstractVector{<:Real}, - aux, - out::DiffResults.MutableDiffResult, -) - ∇x = DiffResults.gradient(out) - fill!(∇x, zero(eltype(∇x))) - _, y = Enzyme.autodiff( - Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), - Enzyme.Const(f), - Enzyme.Active, - Enzyme.Duplicated(x, ∇x), - Enzyme.Const(aux), - ) - DiffResults.value!(out, y) - return out -end - end diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl deleted file mode 100644 index 6904fa7a..00000000 --- a/ext/AdvancedVIForwardDiffExt.jl +++ /dev/null @@ -1,42 +0,0 @@ - -module AdvancedVIForwardDiffExt - -if isdefined(Base, :get_extension) - using ForwardDiff - using AdvancedVI - using AdvancedVI: ADTypes, DiffResults -else - using ..ForwardDiff - using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults -end - -getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize - -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, - f, - x::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, -) - chunk_size = getchunksize(ad) - config = if isnothing(chunk_size) - ForwardDiff.GradientConfig(f, x) - else - ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size)) - end - ForwardDiff.gradient!(out, f, x, config) - return out -end - -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, - f, - x::AbstractVector, - aux, - out::DiffResults.MutableDiffResult, -) - return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) -end - -end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl deleted file mode 100644 index 9cde91a1..00000000 --- a/ext/AdvancedVIReverseDiffExt.jl +++ /dev/null @@ -1,36 +0,0 @@ - -module AdvancedVIReverseDiffExt - -if isdefined(Base, :get_extension) - using AdvancedVI - using AdvancedVI: ADTypes, DiffResults - using ReverseDiff -else - using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults - using ..ReverseDiff -end - -# ReverseDiff without compiled tape -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, - f, - x::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, -) - tp = ReverseDiff.GradientTape(f, x) - ReverseDiff.gradient!(out, tp, x) - return out -end - -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, - f, - x::AbstractVector{<:Real}, - aux, - out::DiffResults.MutableDiffResult, -) - return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) -end - -end diff --git a/ext/AdvancedVITapirExt.jl b/ext/AdvancedVITapirExt.jl 
deleted file mode 100644 index 459ef7da..00000000 --- a/ext/AdvancedVITapirExt.jl +++ /dev/null @@ -1,37 +0,0 @@ -module AdvancedVITapirExt - -if isdefined(Base, :get_extension) - using AdvancedVI - using AdvancedVI: ADTypes, DiffResults - using Tapir -else - using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults - using ..Tapir -end - -function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoTapir, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult -) - rule = Tapir.build_rrule(f, x) - y, g = Tapir.value_and_gradient!!(rule, f, x) - DiffResults.value!(out, y) - DiffResults.gradient!(out, last(g)) - return out -end - -function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoTapir, - f, - x::AbstractVector{<:Real}, - aux, - out::DiffResults.MutableDiffResult, -) - rule = Tapir.build_rrule(f, x, aux) - y, g = Tapir.value_and_gradient!!(rule, f, x, aux) - DiffResults.value!(out, y) - DiffResults.gradient!(out, g[2]) - return out -end - -end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl deleted file mode 100644 index 2cdd8392..00000000 --- a/ext/AdvancedVIZygoteExt.jl +++ /dev/null @@ -1,36 +0,0 @@ - -module AdvancedVIZygoteExt - -if isdefined(Base, :get_extension) - using AdvancedVI - using AdvancedVI: ADTypes, DiffResults - using ChainRulesCore - using Zygote -else - using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults - using ..ChainRulesCore - using ..Zygote -end - -function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult -) - y, back = Zygote.pullback(f, x) - ∇x = back(one(y)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, only(∇x)) - return out -end - -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoZygote, - f, - x::AbstractVector{<:Real}, - aux, - out::DiffResults.MutableDiffResult, -) - return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) -end - -end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 5402e075..1aaee9c5 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -16,30 +16,14 @@ using LinearAlgebra using LogDensityProblems -using ADTypes, DiffResults +using ADTypes +using DifferentiationInterface using ChainRulesCore using FillArrays using StatsBase -# derivatives -""" - value_and_gradient!(ad, f, x, out) - value_and_gradient!(ad, f, x, aux, out) - -Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`. -`f` may receive auxiliary input as `f(x,aux)`. - -# Arguments -- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. -- `f`: Function subject to differentiation. -- `x`: The point to evaluate the gradient. -- `aux`: Auxiliary input passed to `f`. -- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. -""" -function value_and_gradient! end - """ restructure_ad_forward(adtype, restructure, params) @@ -131,7 +115,7 @@ function estimate_objective end export estimate_objective """ - estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state) + estimate_gradient(rng, obj, adtype, prob, λ, restructure, obj_state) Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` @@ -139,7 +123,6 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ - `rng::Random.AbstractRNG`: Random number generator. 
- `obj::AbstractVariationalObjective`: Variational objective. - `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. -- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. - `λ`: Variational parameters to evaluate the gradient on. - `restructure`: Function that reconstructs the variational approximation from `λ`. diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index e6f04ae8..f13389ce 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -101,11 +101,10 @@ function estimate_repgradelbo_ad_forward(params′, aux) return -elbo end -function estimate_gradient!( +function estimate_gradient( rng::Random.AbstractRNG, obj::RepGradELBO, adtype::ADTypes.AbstractADType, - out::DiffResults.MutableDiffResult, prob, params, restructure, @@ -120,8 +119,9 @@ function estimate_gradient!( restructure=restructure, q_stop=q_stop, ) - value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out) - nelbo = DiffResults.value(out) + nelbo, g = value_and_gradient( + estimate_repgradelbo_ad_forward, adtype, params, Constant(aux) + ) stat = (elbo=-nelbo,) - return out, nothing, stat + return g, nothing, stat end diff --git a/src/optimize.jl b/src/optimize.jl index eb462ff5..1d324988 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -68,17 +68,15 @@ function optimize( opt_st = maybe_init_optimizer(state_init, optimizer, params) obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure) avg_st = maybe_init_averager(state_init, averager, params) - grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) stats = NamedTuple[] for t in 1:max_iter stat = (iteration=t,) - grad_buf, obj_st, stat′ = estimate_gradient!( + grad, obj_st, stat′ = estimate_gradient( rng, objective, adtype, - grad_buf, problem, params, restructure, @@ -87,7 +85,6 @@ function optimize( ) stat = merge(stat, stat′) - grad = DiffResults.gradient(grad_buf) opt_st, params = update_variational_params!( typeof(q_init), opt_st, params, restructure, grad ) diff --git a/test/Project.toml b/test/Project.toml index ca0fc384..4f23ba93 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,7 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -26,9 +26,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.2.1, 1" Bijectors = "0.13" -DiffResults = "1.0" +DifferentiationInterface = "0.6" Distributions = "0.25.111" DistributionsAD = "0.6.45" +Enzyme = "0.13" FillArrays = "1.6.1" ForwardDiff = "0.10.36" Functors = "0.4.5" @@ -41,6 +42,7 @@ ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StableRNGs = "1.0.0" Statistics = "1" +StatsBase = "0.34" Test = "1" Tracker = "0.2.20" Zygote = "0.6.63" diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 94da09bc..8970056a 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -5,12 +5,12 @@ AD_distributionsad = Dict( :Zygote => AutoZygote(), ) -if 
@isdefined(Tapir) - AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_distributionsad[:Mooncake] = AutoMooncake(; config=nothing) end if @isdefined(Enzyme) - AD_distributionsad[:Enzyme] = AutoEnzyme() + AD_distributionsad[:Enzyme] = AutoEnzyme(; mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const) end @testset "inference RepGradELBO DistributionsAD" begin diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 9e254b6a..b67f584c 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -5,12 +5,12 @@ AD_locationscale = Dict( :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - AD_locationscale[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_locationscale[:Mooncake] = AutoMooncake(; config=nothing) end if @isdefined(Enzyme) - AD_locationscale[:Enzyme] = AutoEnzyme() + AD_locationscale[:Enzyme] = AutoEnzyme(; mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const) end @testset "inference RepGradELBO VILocationScale" begin diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 731326f3..6cb2e45c 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -5,12 +5,12 @@ AD_locationscale_bijectors = Dict( :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - AD_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=nothing) end if @isdefined(Enzyme) - AD_locationscale_bijectors[:Enzyme] = AutoEnzyme() + AD_locationscale_bijectors[:Enzyme] = AutoEnzyme(; mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const) end @testset "inference RepGradELBO VILocationScale Bijectors" begin diff --git a/test/interface/ad.jl b/test/interface/ad.jl deleted file mode 100644 index e8f4da4e..00000000 --- a/test/interface/ad.jl +++ /dev/null @@ -1,45 +0,0 @@ - -using Test - -const interface_ad_backends = Dict( - :ForwardDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), -) - -if @isdefined(Tapir) - interface_ad_backends[:Tapir] = AutoTapir(; safe_mode=false) -end - -if @isdefined(Enzyme) - interface_ad_backends[:Enzyme] = AutoEnzyme() -end - -@testset "ad" begin - @testset "$(adname)" for (adname, adtype) in interface_ad_backends - D = 10 - A = randn(D, D) - λ = randn(D) - grad_buf = DiffResults.GradientResult(λ) - f(λ′) = λ′' * A * λ′ / 2 - AdvancedVI.value_and_gradient!(adtype, f, λ, grad_buf) - ∇ = DiffResults.gradient(grad_buf) - f = DiffResults.value(grad_buf) - @test ∇ ≈ (A + A') * λ / 2 - @test f ≈ λ' * A * λ / 2 - end - - @testset "$(adname) with auxiliary input" for (adname, adtype) in interface_ad_backends - D = 10 - A = randn(D, D) - λ = randn(D) - b = randn(D) - grad_buf = DiffResults.GradientResult(λ) - f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) - AdvancedVI.value_and_gradient!(adtype, f, λ, (b=b,), grad_buf) - ∇ = DiffResults.gradient(grad_buf) - f = DiffResults.value(grad_buf) - @test ∇ ≈ (A + A') * λ / 2 + b - @test f ≈ λ' * A * λ / 2 + dot(b, λ) - end -end diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index baf1499a..9c335200 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -37,11 +37,11 @@ end ad_backends = [ ADTypes.AutoForwardDiff(), 
ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() ] - if @isdefined(Tapir) - push!(ad_backends, AutoTapir(; safe_mode=false)) + if @isdefined(Mooncake) + push!(ad_backends, AutoMooncake(; config=nothing)) end if @isdefined(Enzyme) - push!(ad_backends, AutoEnzyme()) + push!(ad_backends, AutoEnzyme(; mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const)) end @testset for ad in ad_backends @@ -53,10 +53,9 @@ end out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad) - AdvancedVI.value_and_gradient!( - ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out + grad = value_and_gradient( + AdvancedVI.estimate_repgradelbo_ad_forward, ad, params, Constant(aux) ) - grad = DiffResults.gradient(out) @test norm(grad) ≈ 0 atol = 1e-5 end end diff --git a/test/runtests.jl b/test/runtests.jl index 43958e8e..4b76b2e0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,8 @@ using Bijectors using Distributions using FillArrays using LinearAlgebra +using LogDensityProblems +using Optimisers using PDMats using Pkg using Random, StableRNGs @@ -18,14 +20,14 @@ using Functors using DistributionsAD @functor TuringDiagMvNormal -using LogDensityProblems -using Optimisers + using ADTypes +using DifferentiationInterface using ForwardDiff, ReverseDiff, Zygote if VERSION >= v"1.10" - Pkg.add("Tapir") - using Tapir + Pkg.add("Mooncake") + using Mooncake using Enzyme end @@ -47,7 +49,6 @@ include("models/normallognormal.jl") # Tests if GROUP == "All" || GROUP == "Interface" - include("interface/ad.jl") include("interface/optimize.jl") include("interface/repgradelbo.jl") include("interface/rules.jl") From c7b300f683c15ea55112a9c5fa77b2dc94dd34f7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 22:51:51 -0700 Subject: [PATCH 02/16] run formatter --- src/optimize.jl | 9 +-------- test/inference/repgradelbo_distributionsad.jl | 4 +++- test/inference/repgradelbo_locationscale.jl | 4 +++- test/inference/repgradelbo_locationscale_bijectors.jl | 4 +++- test/interface/repgradelbo.jl | 7 ++++++- test/runtests.jl | 1 - 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 1d324988..00e49109 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -74,14 +74,7 @@ function optimize( stat = (iteration=t,) grad, obj_st, stat′ = estimate_gradient( - rng, - objective, - adtype, - problem, - params, - restructure, - obj_st, - objargs..., + rng, objective, adtype, problem, params, restructure, obj_st, objargs... 
) stat = merge(stat, stat′) diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 8970056a..4086a205 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -10,7 +10,9 @@ if @isdefined(Mooncake) end if @isdefined(Enzyme) - AD_distributionsad[:Enzyme] = AutoEnzyme(; mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const) + AD_distributionsad[:Enzyme] = AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ) end @testset "inference RepGradELBO DistributionsAD" begin diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index b67f584c..87c626f8 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -10,7 +10,9 @@ if @isdefined(Mooncake) end if @isdefined(Enzyme) - AD_locationscale[:Enzyme] = AutoEnzyme(; mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const) + AD_locationscale[:Enzyme] = AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ) end @testset "inference RepGradELBO VILocationScale" begin diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 6cb2e45c..167fe389 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -10,7 +10,9 @@ if @isdefined(Mooncake) end if @isdefined(Enzyme) - AD_locationscale_bijectors[:Enzyme] = AutoEnzyme(; mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const) + AD_locationscale_bijectors[:Enzyme] = AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ) end @testset "inference RepGradELBO VILocationScale Bijectors" begin diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 9c335200..bd698152 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -41,7 +41,12 @@ end push!(ad_backends, AutoMooncake(; config=nothing)) end if @isdefined(Enzyme) - push!(ad_backends, AutoEnzyme(; mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const)) + push!( + ad_backends, + AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ), + ) end @testset for ad in ad_backends diff --git a/test/runtests.jl b/test/runtests.jl index 4b76b2e0..322517a4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,7 +20,6 @@ using Functors using DistributionsAD @functor TuringDiagMvNormal - using ADTypes using DifferentiationInterface using ForwardDiff, ReverseDiff, Zygote From c4e2db4edbdc60e275d2c19a6d2afbabf8fa9707 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 22:53:58 -0700 Subject: [PATCH 03/16] tighten compat bound for ADTypes --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 66b5cb08..6932b4c1 100644 --- a/Project.toml +++ b/Project.toml @@ -33,7 +33,7 @@ AdvancedVIBijectorsExt = "Bijectors" AdvancedVIEnzymeExt = "Enzyme" [compat] -ADTypes = "0.1, 0.2, 1" +ADTypes = "1" Accessors = "0.1" Bijectors = "0.13" ChainRulesCore = "1.16" From c2593c270f8ef0d236001804c70b896084d1b83e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 22:55:40 -0700 Subject: [PATCH 04/16] fix compat bound for docs --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index f42b21bc..8dc25a3b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -14,7 +14,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] -ADTypes = "0.1.6" +ADTypes = "1" AdvancedVI = "0.3" Bijectors = "0.13.6" Distributions = "0.25" From a8d2ee90237c1063596d1978971702319a635145 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 23:02:48 -0700 Subject: [PATCH 05/16] add weakdeps in extras too for Julia < 1.9 --- Project.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Project.toml b/Project.toml index 6932b4c1..739e5e57 100644 --- a/Project.toml +++ b/Project.toml @@ -59,6 +59,11 @@ julia = "1.7" [extras] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From 32a2fe09a6fb5fe1dac7b64cf1d0d3de70f7be08 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 23:05:07 -0700 Subject: [PATCH 06/16] add basic buildkite pipeline --- .buildkite/pipeline.yml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 .buildkite/pipeline.yml diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml new file mode 100644 index 00000000..cf5dcdf3 --- /dev/null +++ b/.buildkite/pipeline.yml @@ -0,0 +1,18 @@ +steps: + - label: "CUDA with julia {{matrix.julia}}" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + agents: + queue: "juliagpu" + cuda: "*" + timeout_in_minutes: 60 + env: + GROUP: "GPU" + ADVANCEDVI_TEST_CUDA: "true" + matrix: + setup: + julia: + - "1.10" From 9f6fbac97df3cfc4803b280e05bed29b1db6873a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 23:09:26 -0700 Subject: [PATCH 07/16] fix: remove Enzyme dependency from tests for Julia 1.7 --- test/Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 4f23ba93..290183b1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,7 +4,6 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -29,7 +28,6 @@ Bijectors = "0.13" DifferentiationInterface = "0.6" Distributions = "0.25.111" DistributionsAD = "0.6.45" -Enzyme = "0.13" FillArrays = "1.6.1" ForwardDiff = "0.10.36" Functors = "0.4.5" From c102f80ddb82101a4a945d12b3da390403679d41 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 23:18:32 -0700 Subject: [PATCH 08/16] revert "add weakdeps in extras too for Julia < 1.9" --- Project.toml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/Project.toml b/Project.toml index 739e5e57..6932b4c1 100644 --- a/Project.toml +++ b/Project.toml @@ -59,11 +59,6 @@ julia = "1.7" [extras] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ForwardDiff =
"f6369f11-7733-5829-9624-2563aa707210" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From b765d48584ca30fe35924a0468c1cfc963cfe8f7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 23:28:19 -0700 Subject: [PATCH 09/16] add back extras --- Project.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Project.toml b/Project.toml index 6932b4c1..739e5e57 100644 --- a/Project.toml +++ b/Project.toml @@ -59,6 +59,11 @@ julia = "1.7" [extras] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From 6f1e98995e3221b2e6bace7680e3e2fc346e59f1 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 23:31:13 -0700 Subject: [PATCH 10/16] add empty exts because Julia 1.7 wants them --- ext/AdvancedVIForwardDiffExt.jl | 0 ext/AdvancedVIMooncakeExt.jl | 0 ext/AdvancedVIReverseDiffExt.jl | 0 ext/AdvancedVIZygoteExt.jl | 0 4 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 ext/AdvancedVIForwardDiffExt.jl create mode 100644 ext/AdvancedVIMooncakeExt.jl create mode 100644 ext/AdvancedVIReverseDiffExt.jl create mode 100644 ext/AdvancedVIZygoteExt.jl diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl new file mode 100644 index 00000000..e69de29b diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl new file mode 100644 index 00000000..e69de29b diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl new file mode 100644 index 00000000..e69de29b diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl new file mode 100644 index 00000000..e69de29b From 90a865c1c86d450b13cbf1a6a4d5c50c100b9016 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Sep 2024 23:33:19 -0700 Subject: [PATCH 11/16] fix missing Enzyme bug --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 322517a4..064c533e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,7 @@ using ForwardDiff, ReverseDiff, Zygote if VERSION >= v"1.10" Pkg.add("Mooncake") + Pkg.add("Enzyme") using Mooncake using Enzyme end From 61c499977ee1fce7c6c74ffeaedf803ce0ebccab Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Sep 2024 22:18:54 -0700 Subject: [PATCH 12/16] revert to old AD interface, redirect to DI except for Enzyme --- ext/AdvancedVIEnzymeExt.jl | 12 ++++++++++++ src/AdvancedVI.jl | 20 ++++++++++++++++++++ src/objectives/elbo/repgradelbo.jl | 4 +--- test/interface/repgradelbo.jl | 8 ++++---- test/runtests.jl | 2 +- 5 files changed, 38 insertions(+), 8 deletions(-) diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 45eaa329..6f020453 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -14,4 +14,16 @@ function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, pa return restructure(params)::typeof(restructure.model) end +function AdvancedVI.value_and_gradient(::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, aux) + ∇x = zero(x) + _, y = 
Enzyme.autodiff( + Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), + Enzyme.Const(f), + Enzyme.Active, + Enzyme.Duplicated(x, ∇x), + Enzyme.Const(aux), + ) + return y, ∇x +end + end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 1aaee9c5..f1955003 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -24,6 +24,26 @@ using FillArrays using StatsBase +# Derivatives +""" + value_and_gradient(ad, f, x, aux, out) + +Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`. +`f` may receive auxiliary input as `f(x,aux)`. + +# Arguments +- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. +- `f`: Function subject to differentiation. +- `x`: The point to evaluate the gradient. +- `aux`: Auxiliary input passed to `f`. + +# Returns +- `value`: `f` evaluated at `x`. +- `grad`: Gradient of `f` evaluated at `x`. +""" +value_and_gradient(ad::ADTypes.AbstractADType, f, x, aux) = + DifferentiationInterface.value_and_gradient(f, ad, x, Constant(aux)) + """ restructure_ad_forward(adtype, restructure, params) diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index f13389ce..8accfe44 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -119,9 +119,7 @@ function estimate_gradient( restructure=restructure, q_stop=q_stop, ) - nelbo, g = value_and_gradient( - estimate_repgradelbo_ad_forward, adtype, params, Constant(aux) - ) + nelbo, g = value_and_gradient(adtype, estimate_repgradelbo_ad_forward, params, aux) stat = (elbo=-nelbo,) return g, nothing, stat end diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index bd698152..01bb4e9a 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -49,7 +49,7 @@ end ) end - @testset for ad in ad_backends + @testset for adtype in ad_backends q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) @@ -57,9 +57,9 @@ end obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad) - grad = value_and_gradient( - AdvancedVI.estimate_repgradelbo_ad_forward, ad, params, Constant(aux) + aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=adtype) + grad = AdvancedVI.value_and_gradient( + adtype, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux ) @test norm(grad) ≈ 0 atol = 1e-5 end diff --git a/test/runtests.jl b/test/runtests.jl index 064c533e..20190922 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,7 +21,6 @@ using DistributionsAD @functor TuringDiagMvNormal using ADTypes -using DifferentiationInterface using ForwardDiff, ReverseDiff, Zygote if VERSION >= v"1.10" @@ -49,6 +48,7 @@ include("models/normallognormal.jl") # Tests if GROUP == "All" || GROUP == "Interface" + include("interface/ad.jl") include("interface/optimize.jl") include("interface/repgradelbo.jl") include("interface/rules.jl") From 169b368a0379ad643700a8e6183d1a135913c870 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Sep 2024 22:19:31 -0700 Subject: [PATCH 13/16] add missing test file --- test/interface/ad.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 test/interface/ad.jl diff --git a/test/interface/ad.jl b/test/interface/ad.jl new file mode 100644 index 00000000..f6bccbb8 --- 
/dev/null +++ b/test/interface/ad.jl @@ -0,0 +1,29 @@ + +using Test + +const interface_ad_backends = Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), +) + +if @isdefined(Tapir) + interface_ad_backends[:Tapir] = AutoTapir(; safe_mode=false) +end + +if @isdefined(Enzyme) + interface_ad_backends[:Enzyme] = AutoEnzyme() +end + +@testset "ad" begin + @testset "$(adname)" for (adname, adtype) in interface_ad_backends + D = 10 + A = randn(D, D) + λ = randn(D) + b = randn(D) + f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) + fval, grad = AdvancedVI.value_and_gradient(adtype, f, λ, (b=b,)) + @test grad ≈ (A + A') * λ / 2 + b + @test fval ≈ λ' * A * λ / 2 + dot(b, λ) + end +end From 36c70e7e32e85d8f6fc31095205e5aee98826992 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Sep 2024 23:08:32 -0700 Subject: [PATCH 14/16] merge upstream, revert to DiffResults gradient interface --- Project.toml | 6 ++-- ext/AdvancedVIEnzymeExt.jl | 18 ++++++++---- src/AdvancedVI.jl | 21 ++++++++------ src/objectives/elbo/repgradelbo.jl | 8 +++-- src/optimize.jl | 16 ++++++++-- test/Project.toml | 2 ++ test/inference/repgradelbo_locationscale.jl | 2 +- .../scoregradelbo_distributionsad.jl | 10 +++---- test/inference/scoregradelbo_locationscale.jl | 10 ++++--- .../scoregradelbo_locationscale_bijectors.jl | 8 ++--- test/interface/ad.jl | 9 ++++-- test/interface/repgradelbo.jl | 5 ++-- test/interface/scoregradelbo.jl | 29 ------------------- test/runtests.jl | 1 + 14 files changed, 75 insertions(+), 70 deletions(-) diff --git a/Project.toml b/Project.toml index 739e5e57..572ea144 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.3.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -37,6 +38,7 @@ ADTypes = "1" Accessors = "0.1" Bijectors = "0.13" ChainRulesCore = "1.16" +DiffResults = "1" DifferentiationInterface = "0.6" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" @@ -62,10 +64,10 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Pkg", "Test"] diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 6f020453..a4119c3b 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -3,19 +3,26 @@ module AdvancedVIEnzymeExt if isdefined(Base, :get_extension) using Enzyme using AdvancedVI - using AdvancedVI: ADTypes + using AdvancedVI: ADTypes, DiffResults else using ..Enzyme using ..AdvancedVI - using ..AdvancedVI: ADTypes + using ..AdvancedVI: ADTypes, DiffResults end function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params) return restructure(params)::typeof(restructure.model) end -function AdvancedVI.value_and_gradient(::ADTypes.AutoEnzyme, f, 
x::AbstractVector{<:Real}, aux) - ∇x = zero(x) +function AdvancedVI.value_and_gradient!( + ::ADTypes.AutoEnzyme, + f, + x::AbstractVector{<:Real}, + aux, + out::DiffResults.MutableDiffResult, +) + ∇x = DiffResults.gradient(out) + fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), Enzyme.Const(f), @@ -23,7 +30,8 @@ function AdvancedVI.value_and_gradient(::ADTypes.AutoEnzyme, f, x::AbstractVecto Enzyme.Duplicated(x, ∇x), Enzyme.Const(aux), ) - return y, ∇x + DiffResults.value!(out, y) + return out end end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index a71e53c7..38ca3551 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -17,6 +17,7 @@ using LinearAlgebra using LogDensityProblems using ADTypes +using DiffResults using DifferentiationInterface using ChainRulesCore @@ -26,7 +27,7 @@ using StatsBase # Derivatives """ - value_and_gradient(ad, f, x, aux, out) + value_and_gradient!(ad, f, x, aux, out) Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`. `f` may receive auxiliary input as `f(x,aux)`. @@ -36,13 +37,14 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif - `f`: Function subject to differentiation. - `x`: The point to evaluate the gradient. - `aux`: Auxiliary input passed to `f`. - -# Returns -- `value`: `f` evaluated at `x`. -- `grad`: Gradient of `f` evaluated at `x`. +- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. """ -value_and_gradient(ad::ADTypes.AbstractADType, f, x, aux) = - DifferentiationInterface.value_and_gradient(f, ad, x, Constant(aux)) +function value_and_gradient!(ad::ADTypes.AbstractADType, f, x, aux, out::DiffResults.MutableDiffResult) + grad_buf = DiffResults.gradient(out) + y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux)) + DiffResults.value!(out, y) + return out +end """ restructure_ad_forward(adtype, restructure, params) @@ -135,7 +137,7 @@ function estimate_objective end export estimate_objective """ - estimate_gradient(rng, obj, adtype, prob, λ, restructure, obj_state) + estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state) Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` @@ -143,8 +145,9 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ - `rng::Random.AbstractRNG`: Random number generator. - `obj::AbstractVariationalObjective`: Variational objective. - `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. +- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. -- `λ`: Variational parameters to evaluate the gradient on. +- `params`: Variational parameters to evaluate the gradient on. - `restructure`: Function that reconstructs the variational approximation from `λ`. - `obj_state`: Previous state of the objective. 
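The restored `value_and_gradient!` above takes a caller-owned DiffResults buffer, which the `optimize` hunk below exploits by allocating `grad_buf` once and reusing it across iterations. A minimal sketch of the calling pattern; the quadratic `f`, the `aux` tuple, and the ForwardDiff backend are illustrative stand-ins, not part of the patch:

using ADTypes, DiffResults, ForwardDiff
import AdvancedVI

f(x, aux) = sum(abs2, x) / 2 + sum(aux.b .* x)             # toy objective taking auxiliary input
x = randn(5)
aux = (b=randn(5),)
out = DiffResults.DiffResult(zero(eltype(x)), similar(x))  # one buffer for both value and gradient

AdvancedVI.value_and_gradient!(AutoForwardDiff(), f, x, aux, out)
DiffResults.value(out)     # f(x, aux)
DiffResults.gradient(out)  # x .+ aux.b, written in place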
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 4d9650fe..b8bf63fa 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -94,10 +94,11 @@ function estimate_repgradelbo_ad_forward(params′, aux) return -elbo end -function estimate_gradient( +function estimate_gradient!( rng::Random.AbstractRNG, obj::RepGradELBO, adtype::ADTypes.AbstractADType, + out::DiffResults.MutableDiffResult, prob, params, restructure, @@ -112,7 +113,8 @@ function estimate_gradient( restructure=restructure, q_stop=q_stop, ) - nelbo, g = value_and_gradient(adtype, estimate_repgradelbo_ad_forward, params, aux) + value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out) + nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - return g, nothing, stat + return out, nothing, stat end diff --git a/src/optimize.jl b/src/optimize.jl index 5b58ef73..8ef9db76 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -42,7 +42,7 @@ The arguments are as follows: - `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation. - `gradient`: The estimated (possibly stochastic) gradient. -`cb` can return a `NamedTuple` containing some additional information computed within `cb`. +`callback` can return a `NamedTuple` containing some additional information computed within `callback`. This will be appended to the statistic of the current corresponding iteration. Otherwise, just return `nothing`. @@ -68,15 +68,25 @@ function optimize( opt_st = maybe_init_optimizer(state_init, optimizer, params) obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure) avg_st = maybe_init_averager(state_init, averager, params) + grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) stats = NamedTuple[] for t in 1:max_iter stat = (iteration=t,) - grad, obj_st, stat′ = estimate_gradient( - rng, objective, adtype, problem, params, restructure, obj_st, objargs...
+ grad_buf, obj_st, stat′ = estimate_gradient!( + rng, + objective, + adtype, + grad_buf, + problem, + params, + restructure, + obj_st, + objargs..., ) stat = merge(stat, stat′) + grad = DiffResults.gradient(grad_buf) opt_st, params = update_variational_params!( typeof(q_init), opt_st, params, restructure, grad ) diff --git a/test/Project.toml b/test/Project.toml index 290183b1..bbf9c4c6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" @@ -25,6 +26,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.2.1, 1" Bijectors = "0.13" +DiffResults = "1" DifferentiationInterface = "0.6" Distributions = "0.25.111" DistributionsAD = "0.6.45" diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 87c626f8..1ca31885 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -15,7 +15,7 @@ if @isdefined(Enzyme) ) end -@testset "inference RepGradELBO VILocationScale" begin +@testset "inference ScoreGradELBO VILocationScale" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl index 700dda6d..1de7af1d 100644 --- a/test/inference/scoregradelbo_distributionsad.jl +++ b/test/inference/scoregradelbo_distributionsad.jl @@ -1,19 +1,19 @@ -AD_distributionsad = Dict( +AD_scoregradelbo_distributionsad = Dict( :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment :Zygote => AutoZygote(), ) if @isdefined(Tapir) - AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false) + AD_scoregradelbo_distributionsad[:Tapir] = AutoTapir(; safe_mode=false) end #if @isdefined(Enzyme) -# AD_distributionsad[:Enzyme] = AutoEnzyme() +# AD_scoregradelbo_distributionsad[:Enzyme] = AutoEnzyme() #end -@testset "inference RepGradELBO DistributionsAD" begin +@testset "inference ScoreGradELBO DistributionsAD" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield), @@ -23,7 +23,7 @@ end :ScoreGradELBOStickingTheLanding => ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in AD_distributionsad + (adbackname, adtype) in AD_scoregradelbo_distributionsad seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl index ef49713b..f0073d7c 100644 --- a/test/inference/scoregradelbo_locationscale.jl +++ b/test/inference/scoregradelbo_locationscale.jl @@ -1,16 +1,18 @@ -AD_locationscale = Dict( +AD_scoregradelbo_locationscale = Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - AD_locationscale[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=nothing) end if @isdefined(Enzyme) - AD_locationscale[:Enzyme] = 
AutoEnzyme() + AD_scoregradelbo_locationscale[:Enzyme] = AutoEnzyme(; + mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const + ) end @testset "inference ScoreGradELBO VILocationScale" begin diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl index 088130aa..bee8234a 100644 --- a/test/inference/scoregradelbo_locationscale_bijectors.jl +++ b/test/inference/scoregradelbo_locationscale_bijectors.jl @@ -1,16 +1,16 @@ -AD_locationscale_bijectors = Dict( +AD_scoregradelbo_locationscale_bijectors = Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), #:Zygote => AutoZygote(), ) #if @isdefined(Tapir) -# AD_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) +# AD_scoregradelbo_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) #end if @isdefined(Enzyme) - AD_locationscale_bijectors[:Enzyme] = AutoEnzyme() + AD_scoregradelbo_locationscale_bijectors[:Enzyme] = AutoEnzyme() end @testset "inference ScoreGradELBO VILocationScale Bijectors" begin @@ -24,7 +24,7 @@ end :ScoreGradELBOStickingTheLanding => ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()), ), - (adbackname, adtype) in AD_locationscale_bijectors + (adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index f6bccbb8..713a0f56 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -21,9 +21,12 @@ end A = randn(D, D) λ = randn(D) b = randn(D) + grad_buf = DiffResults.GradientResult(λ) f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) - fval, grad = AdvancedVI.value_and_gradient(adtype, f, λ, (b=b,)) - @test grad ≈ (A + A') * λ / 2 + b - @test fval ≈ λ' * A * λ / 2 + dot(b, λ) + AdvancedVI.value_and_gradient!(adtype, f, λ, (b=b,), grad_buf) + ∇ = DiffResults.gradient(grad_buf) + f = DiffResults.value(grad_buf) + @test ∇ ≈ (A + A') * λ / 2 + b + @test f ≈ λ' * A * λ / 2 + dot(b, λ) end end diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 01bb4e9a..faad924d 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -58,9 +58,10 @@ end out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=adtype) - grad = AdvancedVI.value_and_gradient( - adtype, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux + AdvancedVI.value_and_gradient!( + adtype, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out ) + grad = DiffResults.gradient(out) @test norm(grad) ≈ 0 atol = 1e-5 end end diff --git a/test/interface/scoregradelbo.jl b/test/interface/scoregradelbo.jl index a800f744..8a6ebb14 100644 --- a/test/interface/scoregradelbo.jl +++ b/test/interface/scoregradelbo.jl @@ -26,32 +26,3 @@ using Test @test elbo ≈ elbo_ref rtol = 0.2 end end - -@testset "interface ScoreGradELBO STL variance reduction" begin - seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) - - modelstats = normal_meanfield(rng, Float64) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - @testset for ad in [ - ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() - ] - q_true = MeanFieldGaussian( - Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) - ) - params, re = Optimisers.destructure(q_true) - obj = ScoreGradELBO( - 1000; entropy=StickingTheLandingEntropy(), baseline_history=[0.0] - ) - out = 
DiffResults.DiffResult(zero(eltype(params)), similar(params)) - - aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad) - AdvancedVI.value_and_gradient!( - ad, AdvancedVI.estimate_scoregradelbo_ad_forward, params, aux, out - ) - value = DiffResults.value(out) - grad = DiffResults.gradient(out) - @test norm(grad) ≈ 0 atol = 10 # high tolerance required. - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 19a03549..7c0e3129 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Test: @testset, @test using Base.Iterators using Bijectors +using DiffResults using Distributions using FillArrays using LinearAlgebra From dce99d7cda062713d184713bf03b5a0762507ab4 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 30 Sep 2024 02:09:38 -0400 Subject: [PATCH 15/16] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/interface/repgradelbo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index faad924d..afd6249e 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -57,7 +57,9 @@ end obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=adtype) + aux = ( + rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=adtype + ) AdvancedVI.value_and_gradient!( adtype, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out ) From 79e5c629fcd30fdd22013d35cd06a125a656218d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 30 Sep 2024 02:11:50 -0400 Subject: [PATCH 16/16] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/AdvancedVI.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 38ca3551..aebe765e 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -39,7 +39,9 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif - `aux`: Auxiliary input passed to `f`. - `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. """ -function value_and_gradient!(ad::ADTypes.AbstractADType, f, x, aux, out::DiffResults.MutableDiffResult) +function value_and_gradient!( + ad::ADTypes.AbstractADType, f, x, aux, out::DiffResults.MutableDiffResult +) grad_buf = DiffResults.gradient(out) y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux)) DiffResults.value!(out, y)
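After the series, Enzyme remains the one backend that bypasses DifferentiationInterface: the extension keeps the hand-written method reinstated in PATCH 14, which zeroes the shadow vector, enables runtime activity, and wraps both `f` and the auxiliary input in `Enzyme.Const`. Stripped of the package plumbing, the underlying `autodiff` call follows this pattern; a sketch with a toy `f`, not package code:

using Enzyme

f(x, aux) = sum(abs2, x) / 2 + sum(aux.b .* x)
x = randn(3)
∇x = zero(x)          # shadow buffer; Enzyme accumulates into it, so it must start at zero
aux = (b=randn(3),)
_, y = Enzyme.autodiff(
    Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),  # reverse mode, returning the primal too
    Enzyme.Const(f),            # the function itself carries no differentiable state
    Enzyme.Active,              # the scalar return is active
    Enzyme.Duplicated(x, ∇x),   # differentiate with respect to x; gradient lands in ∇x
    Enzyme.Const(aux),          # auxiliary input is held constant
)
# y == f(x, aux); ∇x now holds the gradient x .+ aux.b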