Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Tapir support #71

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1'
- '1.6'
#- '1.7'
- '1.10'
Copy link
Member

@yebai yebai Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@willtebbutt, can you adapt the Bijectors setup so we don't need to comment out 1.7?

os:
- ubuntu-latest
- macOS-latest
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIBijectorsExt = "Bijectors"
AdvancedVIEnzymeExt = "Enzyme"
AdvancedVIForwardDiffExt = "ForwardDiff"
AdvancedVIReverseDiffExt = "ReverseDiff"
AdvancedVITapirExt = "Tapir"
AdvancedVIZygoteExt = "Zygote"

[compat]
Expand All @@ -55,15 +57,17 @@ Requires = "1.0"
ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
Tapir = "0.2.34"
Zygote = "0.6.63"
julia = "1.6"
julia = "1.7"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
6 changes: 4 additions & 2 deletions ext/AdvancedVIForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff,
::Any,
f,
x::AbstractVector{<:Real},
x::AbstractVector,
out::DiffResults.MutableDiffResult,
)
chunk_size = getchunksize(ad)
Expand All @@ -31,12 +32,13 @@ end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff,
st_ad,
f,
x::AbstractVector,
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
end

end
6 changes: 4 additions & 2 deletions ext/AdvancedVIReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ end

# ReverseDiff without compiled tape
function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
::ADTypes.AutoReverseDiff,
::Any,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
Expand All @@ -25,12 +26,13 @@ end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
st_ad,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
end

end
47 changes: 47 additions & 0 deletions ext/AdvancedVITapirExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

module AdvancedVITapirExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using Tapir
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..Tapir
end

AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x) = Tapir.build_rrule(f, x)

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoTapir,
st_ad,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
rule = st_ad
y, g = Tapir.value_and_gradient!!(rule, f, x)
DiffResults.value!(out, y)
DiffResults.gradient!(out, last(g))
yebai marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DiffResults.gradient!(out, last(g))
DiffResults.gradient!(out, g[2])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@willtebbutt, to clarify, we don't need this change. Is that correct?

return out
end

AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x, aux) = Tapir.build_rrule(f, x, aux)

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoTapir,
st_ad,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
rule = st_ad
y, g = Tapir.value_and_gradient!!(rule, f, x, aux)
DiffResults.value!(out, y)
DiffResults.gradient!(out, last(g))
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
return out
end

end
9 changes: 7 additions & 2 deletions ext/AdvancedVIZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ else
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
::ADTypes.AutoZygote,
::Any,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
y, back = Zygote.pullback(f, x)
∇x = back(one(y))
Expand All @@ -25,12 +29,13 @@ end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoZygote,
st_ad,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
end

end
35 changes: 22 additions & 13 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ using StatsBase

# derivatives
"""
value_and_gradient!(ad, f, x, out)
value_and_gradient!(ad, f, x, aux, out)
value_and_gradient!(adtype, ad_st, f, x, out)
value_and_gradient!(adtype, ad_st, f, x, aux, out)

Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation (AD) backend `ad` and store the result in `out`.
`f` may receive auxiliary input as `f(x,aux)`.

# Arguments
- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
- `adtype::ADTypes.AbstractADType`: AD backend.
- `ad_st`: State used by the AD backend. (This will often be pre-compiled tapes/caches.)
- `f`: Function subject to differentiation.
- `x`: The point to evaluate the gradient.
- `aux`: Auxiliary input passed to `f`.
Expand All @@ -41,18 +42,22 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
function value_and_gradient! end

"""
stop_gradient(x)
init_adbackend(adtype, f, x)
init_adbackend(adtype, f, x, aux)

Stop the gradient from propagating to `x` if the selected ad backend supports it.
Otherwise, it is equivalent to `identity`.
Initialize the AD backend and setup states necessary.

# Arguments
- `x`: Input
- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
- `f`: Function subject to differentiation.
- `x`: The point to evaluate the gradient.
- `aux`: Auxiliary input passed to `f`.

# Returns
- `x`: Same value as the input.
- `ad_st`: State of the AD backend. (This will often be pre-compiled tapes/caches.)
"""
function stop_gradient end
init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any) = nothing
init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any, ::Any) = nothing

# Update for gradient descent step
"""
Expand Down Expand Up @@ -96,18 +101,22 @@ If the estimator is stateful, it can implement `init` to initialize the state.
abstract type AbstractVariationalObjective end

"""
init(rng, obj, prob, params, restructure)
init(rng, obj, adtype, prob, params, restructure)

Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
Initialize a state of the variational objective `obj`.
This function needs to be implemented only if `obj` is stateful.
The state of the AD backend `adtype` shall also be initialized here.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `obj::AbstractVariationalObjective`: Variational objective.
- `adtype::ADTypes.ADType`:Automatic differentiation backend.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `params`: Initial variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
"""
init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any) = nothing
init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any, ::Any) =
nothing

"""
estimate_objective([rng,] obj, q, prob; kwargs...)
Expand Down
19 changes: 17 additions & 2 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,20 @@ function estimate_repgradelbo_ad_forward(params′, aux)
return -elbo
end

function init(
rng::Random.AbstractRNG,
obj::RepGradELBO,
adtype::ADTypes.AbstractADType,
prob,
params,
restructure,
)
q_stop = restructure(params)
aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
ad_st = init_adbackend(adtype, estimate_repgradelbo_ad_forward, params, aux)
return (ad_st=ad_st,)
end

function estimate_gradient!(
rng::Random.AbstractRNG,
obj::RepGradELBO,
Expand All @@ -111,9 +125,10 @@ function estimate_gradient!(
state,
)
q_stop = restructure(params)
ad_st = state.ad_st
aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out)
value_and_gradient!(adtype, ad_st, estimate_repgradelbo_ad_forward, params, aux, out)
nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)
return out, nothing, stat
return out, state, stat
end
4 changes: 3 additions & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ function optimize(
)
params, restructure = Optimisers.destructure(deepcopy(q_init))
opt_st = maybe_init_optimizer(state_init, optimizer, params)
obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
obj_st = maybe_init_objective(
state_init, rng, adtype, objective, problem, params, restructure
)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
stats = NamedTuple[]

Expand Down
3 changes: 2 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ end
function maybe_init_objective(
state_init::NamedTuple,
rng::Random.AbstractRNG,
adtype::ADTypes.AbstractADType,
objective::AbstractVariationalObjective,
problem,
params,
Expand All @@ -24,7 +25,7 @@ function maybe_init_objective(
if haskey(state_init, :objective)
state_init.objective
else
init(rng, objective, problem, params, restructure)
init(rng, objective, adtype, problem, params, restructure)
end
end

Expand Down
4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -40,7 +41,8 @@ ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StableRNGs = "1.0.0"
Statistics = "1"
Tapir = "0.2.23"
Test = "1"
Tracker = "0.2.20"
Zygote = "0.6.63"
julia = "1.6"
julia = "1.7"
1 change: 1 addition & 0 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Tapir => AutoTapir(),
yebai marked this conversation as resolved.
Show resolved Hide resolved
#:Enzyme => AutoEnzyme(),
)

Expand Down
1 change: 1 addition & 0 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
(adbackname, adtype) in Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Tapir => AutoTapir(; safe_mode=false),
:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
)
Expand Down
3 changes: 2 additions & 1 deletion test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
(adbackname, adtype) in Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
:Zygote => AutoZygote(),
:Tapir => AutoTapir(; safe_mode=false),
#:Enzyme => AutoEnzyme(),
)

Expand Down
11 changes: 7 additions & 4 deletions test/interface/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@
using Test

@testset "ad" begin
@testset "$(adname)" for (adname, adsymbol) in Dict(
@testset "$(adname)" for (adname, adtype) in Dict(
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
:Tapir => AutoTapir(),
yebai marked this conversation as resolved.
Show resolved Hide resolved
#:Enzyme => AutoEnzyme()
)
D = 10
A = randn(D, D)
λ = randn(D)
grad_buf = DiffResults.GradientResult(λ)
f(λ′) = λ′' * A * λ′ / 2
AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf)

ad_st = AdvancedVI.init_adbackend(adtype, f, λ)
grad_buf = DiffResults.GradientResult(λ)
AdvancedVI.value_and_gradient!(adtype, ad_st, f, λ, grad_buf)
∇ = DiffResults.gradient(grad_buf)
f = DiffResults.value(grad_buf)
@test ∇ ≈ (A + A') * λ / 2
Expand Down
Loading
Loading