Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to DifferentiationInterface #98

Merged
merged 17 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ version = "0.3.0"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -24,52 +24,43 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIBijectorsExt = "Bijectors"
AdvancedVIEnzymeExt = "Enzyme"
AdvancedVIForwardDiffExt = "ForwardDiff"
AdvancedVIReverseDiffExt = "ReverseDiff"
AdvancedVITapirExt = "Tapir"
AdvancedVIZygoteExt = "Zygote"

[compat]
ADTypes = "0.1, 0.2, 1"
Accessors = "0.1"
Bijectors = "0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
DifferentiationInterface = "0.6"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.13"
FillArrays = "1.3"
ForwardDiff = "0.10.36"
ForwardDiff = "0.10"
Functors = "0.4"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4"
Optimisers = "0.2.16, 0.3"
ProgressMeter = "1.6"
Random = "1"
Requires = "1.0"
ReverseDiff = "1.15.1"
ReverseDiff = "1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
Tapir = "0.2"
Zygote = "0.6.63"
Zygote = "0.6"
julia = "1.7"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Test"]
40 changes: 2 additions & 38 deletions ext/AdvancedVIEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,17 @@

module AdvancedVIEnzymeExt

if isdefined(Base, :get_extension)
using Enzyme
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using AdvancedVI: ADTypes
else
using ..Enzyme
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..AdvancedVI: ADTypes
end

function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params)
return restructure(params)::typeof(restructure.model)
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
∇x = DiffResults.gradient(out)
fill!(∇x, zero(eltype(∇x)))
_, y = Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),
Enzyme.Const(f),
Enzyme.Active,
Enzyme.Duplicated(x, ∇x),
)
DiffResults.value!(out, y)
return out
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoEnzyme,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
∇x = DiffResults.gradient(out)
fill!(∇x, zero(eltype(∇x)))
_, y = Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),
Enzyme.Const(f),
Enzyme.Active,
Enzyme.Duplicated(x, ∇x),
Enzyme.Const(aux),
)
DiffResults.value!(out, y)
return out
end

end
42 changes: 0 additions & 42 deletions ext/AdvancedVIForwardDiffExt.jl

This file was deleted.

36 changes: 0 additions & 36 deletions ext/AdvancedVIReverseDiffExt.jl

This file was deleted.

37 changes: 0 additions & 37 deletions ext/AdvancedVITapirExt.jl

This file was deleted.

36 changes: 0 additions & 36 deletions ext/AdvancedVIZygoteExt.jl

This file was deleted.

23 changes: 3 additions & 20 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,14 @@ using LinearAlgebra

using LogDensityProblems

using ADTypes, DiffResults
using ADTypes
using DifferentiationInterface
using ChainRulesCore

using FillArrays

using StatsBase

# derivatives
"""
value_and_gradient!(ad, f, x, out)
value_and_gradient!(ad, f, x, aux, out)

Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
`f` may receive auxiliary input as `f(x,aux)`.

# Arguments
- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
- `f`: Function subject to differentiation.
- `x`: The point to evaluate the gradient.
- `aux`: Auxiliary input passed to `f`.
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
"""
function value_and_gradient! end

"""
restructure_ad_forward(adtype, restructure, params)

Expand Down Expand Up @@ -131,15 +115,14 @@ function estimate_objective end
export estimate_objective

"""
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
estimate_gradient(rng, obj, adtype, prob, λ, restructure, obj_state)

Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `obj::AbstractVariationalObjective`: Variational objective.
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `λ`: Variational parameters to evaluate the gradient on.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
Expand Down
10 changes: 5 additions & 5 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,10 @@ function estimate_repgradelbo_ad_forward(params′, aux)
return -elbo
end

function estimate_gradient!(
function estimate_gradient(
rng::Random.AbstractRNG,
obj::RepGradELBO,
adtype::ADTypes.AbstractADType,
out::DiffResults.MutableDiffResult,
prob,
params,
restructure,
Expand All @@ -120,8 +119,9 @@ function estimate_gradient!(
restructure=restructure,
q_stop=q_stop,
)
value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out)
nelbo = DiffResults.value(out)
nelbo, g = value_and_gradient(
estimate_repgradelbo_ad_forward, adtype, params, Constant(aux)
)
stat = (elbo=-nelbo,)
return out, nothing, stat
return g, nothing, stat
end
14 changes: 2 additions & 12 deletions src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,16 @@ function optimize(
opt_st = maybe_init_optimizer(state_init, optimizer, params)
obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
avg_st = maybe_init_averager(state_init, averager, params)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
stats = NamedTuple[]

for t in 1:max_iter
stat = (iteration=t,)

grad_buf, obj_st, stat′ = estimate_gradient!(
rng,
objective,
adtype,
grad_buf,
problem,
params,
restructure,
obj_st,
objargs...,
grad, obj_st, stat′ = estimate_gradient(
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
rng, objective, adtype, problem, params, restructure, obj_st, objargs...
)
stat = merge(stat, stat′)

grad = DiffResults.gradient(grad_buf)
opt_st, params = update_variational_params!(
typeof(q_init), opt_st, params, restructure, grad
)
Expand Down
Loading
Loading