Skip to content

Commit

Permalink
Use extensions when possible (#264)
Browse files Browse the repository at this point in the history
* moved compats for forwarddiff, reversediff, and tracker to extensions

* extension for Zygote

* renamd the extensions

* added LazyArrays extension

* moved DistributionsAD to ext

* extensions should now be working

* fix imports for DistributionsAD extension

* fixed imports in ReverseDiff extension

* fixed imports in Tracker extension

* fixed imports in Zygote extension

* formatting

* formmatted extensions

* now importing Compat as its needed by extensions

* think i fixed you extension

* another attempt

* added missing LinearAlgebra qualification for DistributionsAD ext

* Update ext/BijectorsDistributionsADExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update ext/BijectorsDistributionsADExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added back some changes that accidentally was dropped in the merging
with master

* Update ext/BijectorsDistributionsADExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added missing qualification

* only load Requires if needed

* use @static

* load LAzyArrays in tests to make sure extension is working

* fixed ReverseDiff

* fixed imports to Zygote

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added missing method in BijectorsTrackerExt

* removed accidentally included changes in BijectorsZygoteExt

---------

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 13, 2023
1 parent 3a0b7e3 commit 451a2fc
Show file tree
Hide file tree
Showing 11 changed files with 413 additions and 177 deletions.
16 changes: 16 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"

[extensions]
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsTrackerExt = "Tracker"
BijectorsZygoteExt = "Zygote"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsDistributionsADExt = "DistributionsAD"

[compat]
ArgCheck = "1, 2"
ChainRulesCore = "0.10.11, 1"
Expand Down
118 changes: 118 additions & 0 deletions ext/BijectorsDistributionsADExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
module BijectorsDistributionsADExt

if isdefined(Base, :get_extension)
using Bijectors
using Bijectors: LinearAlgebra
using Bijectors.Distributions: AbstractMvLogNormal
using DistributionsAD:
TuringDirichlet,
TuringWishart,
TuringInverseWishart,
FillVectorOfUnivariate,
FillMatrixOfUnivariate,
MatrixOfUnivariate,
FillVectorOfMultivariate,
VectorOfMultivariate,
TuringScalMvNormal,
TuringDiagMvNormal,
TuringDenseMvNormal
else
using ..Bijectors
using ..Bijectors: LinearAlgebra
using ..Bijectors.Distributions: AbstractMvLogNormal
using ..DistributionsAD:
TuringDirichlet,
TuringWishart,
TuringInverseWishart,
FillVectorOfUnivariate,
FillMatrixOfUnivariate,
MatrixOfUnivariate,
FillVectorOfMultivariate,
VectorOfMultivariate,
TuringScalMvNormal,
TuringDiagMvNormal,
TuringDenseMvNormal
end

# Bijectors

Bijectors.bijector(::TuringDirichlet) = Bijectors.SimplexBijector()
Bijectors.bijector(::TuringWishart) = Bijectors.PDBijector()
Bijectors.bijector(::TuringInverseWishart) = Bijectors.PDBijector()
Bijectors.bijector(::TuringScalMvNormal) = identity
Bijectors.bijector(::TuringDiagMvNormal) = identity
Bijectors.bijector(::TuringDenseMvNormal) = identity

function Bijectors.bijector(d::FillVectorOfUnivariate{Continuous})
return elementwise(Bijectors.bijector(d.v.value))
end
function Bijectors.bijector(d::FillMatrixOfUnivariate{Continuous})
return elementwise(Bijectors.bijector(d.dists.value))
end
Bijectors.bijector(d::MatrixOfUnivariate{Discrete}) = identity
function Bijectors.bijector(d::MatrixOfUnivariate{Continuous})
return TruncatedBijectors.Bijector(_minmax(d.dists)...)
end
Bijectors.bijector(d::VectorOfMultivariate{Discrete}) = identity
for T in (:VectorOfMultivariate, :FillVectorOfMultivariate)
@eval begin
Bijectors.bijector(d::$T{Continuous,<:MvNormal}) = identity
Bijectors.bijector(d::$T{Continuous,<:TuringScalMvNormal}) = identity
Bijectors.bijector(d::$T{Continuous,<:TuringDiagMvNormal}) = identity
Bijectors.bijector(d::$T{Continuous,<:TuringDenseMvNormal}) = identity
Bijectors.bijector(d::$T{Continuous,<:MvNormalCanon}) = identity
Bijectors.bijector(d::$T{Continuous,<:AbstractMvLogNormal}) = Log()
function Bijectors.bijector(d::$T{Continuous,<:SimplexDistribution})
return Bijectors.SimplexBijector()
end
function Bijectors.bijector(d::$T{Continuous,<:TuringDirichlet})
return Bijectors.SimplexBijector()
end
end
end
function Bijectors.bijector(d::FillVectorOfMultivariate{Continuous})
return Bijectors.columnwise(Bijectors.bijector(d.dists.value))
end

Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:Dirichlet}) = true
Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = true
Bijectors.isdirichlet(::TuringDirichlet) = true

function Bijectors.link(
d::TuringDirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return Bijectors.SimplexBijector{proj}()(x)
end

function Bijectors.link_jacobian(
d::TuringDirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(Bijectors.SimplexBijector{proj}(), x)
end

function Bijectors.invlink(
d::TuringDirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return inverse(Bijectors.SimplexBijector{proj}())(y)
end
function Bijectors.invlink_jacobian(
d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(inverse(Bijectors.SimplexBijector{proj}()), y)
end

Bijectors.ispd(::TuringWishart) = true
Bijectors.ispd(::TuringInverseWishart) = true
function Bijectors.getlogp(d::TuringWishart, Xcf, X)
return (
(d.df - (size(d, 1) + 1)) * LinearAlgebra.logdet(Xcf) - LinearAlgebra.tr(d.chol \ X)
) / 2 + d.logc0
end
function Bijectors.getlogp(d::TuringInverseWishart, Xcf, X)
Ψ = d.S
return -(
(d.df + size(d, 1) + 1) * LinearAlgebra.logdet(Xcf) + LinearAlgebra.tr(Xcf \ Ψ)
) / 2 + d.logc0
end

end
18 changes: 14 additions & 4 deletions src/compat/forwarddiff.jl → ext/BijectorsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import .ForwardDiff
module BijectorsForwardDiffExt

_eps(::Type{<:ForwardDiff.Dual{<:Any,Real}}) = _eps(Real)
_eps(::Type{<:ForwardDiff.Dual{<:Any,<:Integer}}) = _eps(Real)
if isdefined(Base, :get_extension)
using Bijectors: Bijectors, find_alpha
using ForwardDiff: ForwardDiff
else
using ..Bijectors: Bijectors, find_alpha
using ..ForwardDiff: ForwardDiff
end

Bijectors._eps(::Type{<:ForwardDiff.Dual{<:Any,Real}}) = Bijectors._eps(Real)
Bijectors._eps(::Type{<:ForwardDiff.Dual{<:Any,<:Integer}}) = Bijectors._eps(Real)

# Define forward-mode rule for ForwardDiff and don't trust support for ForwardDiff in Roots
# https://github.com/JuliaMath/Roots.jl/issues/314
function find_alpha(
function Bijectors.find_alpha(
wt_y::ForwardDiff.Dual{T,<:Real},
wt_u_hat::ForwardDiff.Dual{T,<:Real},
b::ForwardDiff.Dual{T,<:Real},
Expand All @@ -25,3 +33,5 @@ function find_alpha(

return ForwardDiff.Dual{T}(Ω, ∂Ω)
end

end
21 changes: 21 additions & 0 deletions ext/BijectorsLazyArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module BijectorsLazyArraysExt

if isdefined(Base, :get_extension)
import Bijectors: maporbroadcast
using LazyArrays: LazyArrays
else
import ..Bijectors: maporbroadcast
using ..LazyArrays: LazyArrays
end

function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...)
return copy(f.(x1, x...))
end
function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...)
return copy(f.(x1, x2, x...))
end
function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...)
return copy(f.(x1, x2, x3, x...))
end

end
127 changes: 83 additions & 44 deletions src/compat/reversediff.jl → ext/BijectorsReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,86 @@
module ReverseDiffCompat

using ..ReverseDiff:
ReverseDiff,
@grad,
value,
track,
TrackedReal,
TrackedVector,
TrackedMatrix,
@grad_from_chainrules
using Requires, LinearAlgebra

using ..Bijectors:
Elementwise,
SimplexBijector,
maphcat,
simplex_link_jacobian,
simplex_invlink_jacobian,
simplex_logabsdetjac_gradient,
Inverse
import ..Bijectors:
_eps,
logabsdetjac,
_logabsdetjac_scale,
_simplex_bijector,
_simplex_inv_bijector,
replace_diag,
jacobian,
pd_from_lower,
pd_from_upper,
lower_triangular,
upper_triangular,
_inv_link_chol_lkj,
_link_chol_lkj,
_transform_ordered,
_transform_inverse_ordered,
find_alpha,
cholesky_factor

using ChainRulesCore: ChainRulesCore

using Compat: eachcol
using Distributions: LocationScale
module BijectorsReverseDiffExt

if isdefined(Base, :get_extension)
using ReverseDiff:
ReverseDiff,
@grad,
value,
track,
TrackedReal,
TrackedVector,
TrackedMatrix,
@grad_from_chainrules

using Bijectors:
ChainRulesCore,
Elementwise,
SimplexBijector,
maphcat,
simplex_link_jacobian,
simplex_invlink_jacobian,
simplex_logabsdetjac_gradient,
Inverse
import Bijectors:
_eps,
logabsdetjac,
_logabsdetjac_scale,
_simplex_bijector,
_simplex_inv_bijector,
replace_diag,
jacobian,
_inv_link_chol_lkj,
_link_chol_lkj,
_transform_ordered,
_transform_inverse_ordered,
find_alpha,
pd_from_lower,
lower_triangular,
upper_triangular

using Bijectors.LinearAlgebra
using Bijectors.Compat: eachcol
using Bijectors.Distributions: LocationScale
else
using ..ReverseDiff:
ReverseDiff,
@grad,
value,
track,
TrackedReal,
TrackedVector,
TrackedMatrix,
@grad_from_chainrules

using ..Bijectors:
ChainRulesCore,
Elementwise,
SimplexBijector,
maphcat,
simplex_link_jacobian,
simplex_invlink_jacobian,
simplex_logabsdetjac_gradient,
Inverse
import ..Bijectors:
_eps,
logabsdetjac,
_logabsdetjac_scale,
_simplex_bijector,
_simplex_inv_bijector,
replace_diag,
jacobian,
_inv_link_chol_lkj,
_link_chol_lkj,
_transform_ordered,
_transform_inverse_ordered,
find_alpha,
pd_from_lower,
lower_triangular,
upper_triangular

using ..Bijectors.LinearAlgebra
using ..Bijectors.Compat: eachcol
using ..Bijectors.Distributions: LocationScale
end

_eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T)
function Base.minimum(d::LocationScale{<:TrackedReal})
Expand Down
Loading

2 comments on commit 451a2fc

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Tag with name v0.12.6 already exists and points to a different commit"

Please sign in to comment.