From 451a2fc8937a5ac18e9b282e11c9d0a4dcb5c673 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jun 2023 23:37:26 +0100 Subject: [PATCH] Use extensions when possible (#264) * moved compats for forwarddiff, reversediff, and tracker to extensions * extension for Zygote * renamd the extensions * added LazyArrays extension * moved DistributionsAD to ext * extensions should now be working * fix imports for DistributionsAD extension * fixed imports in ReverseDiff extension * fixed imports in Tracker extension * fixed imports in Zygote extension * formatting * formmatted extensions * now importing Compat as its needed by extensions * think i fixed you extension * another attempt * added missing LinearAlgebra qualification for DistributionsAD ext * Update ext/BijectorsDistributionsADExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update ext/BijectorsDistributionsADExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added back some changes that accidentally was dropped in the merging with master * Update ext/BijectorsDistributionsADExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added missing qualification * only load Requires if needed * use @static * load LAzyArrays in tests to make sure extension is working * fixed ReverseDiff * fixed imports to Zygote * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added missing method in BijectorsTrackerExt * removed accidentally included changes in BijectorsZygoteExt --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 16 +++ ext/BijectorsDistributionsADExt.jl | 118 ++++++++++++++++ .../BijectorsForwardDiffExt.jl | 18 ++- ext/BijectorsLazyArraysExt.jl | 21 +++ .../BijectorsReverseDiffExt.jl | 127 ++++++++++++------ .../tracker.jl => ext/BijectorsTrackerExt.jl | 79 +++++++---- .../zygote.jl => ext/BijectorsZygoteExt.jl | 83 +++++++++++- src/Bijectors.jl | 47 ++++--- src/compat/distributionsad.jl | 78 ----------- test/Project.toml | 2 + test/runtests.jl | 1 + 11 files changed, 413 insertions(+), 177 deletions(-) create mode 100644 ext/BijectorsDistributionsADExt.jl rename src/compat/forwarddiff.jl => ext/BijectorsForwardDiffExt.jl (66%) create mode 100644 ext/BijectorsLazyArraysExt.jl rename src/compat/reversediff.jl => ext/BijectorsReverseDiffExt.jl (79%) rename src/compat/tracker.jl => ext/BijectorsTrackerExt.jl (93%) rename src/compat/zygote.jl => ext/BijectorsZygoteExt.jl (72%) delete mode 100644 src/compat/distributionsad.jl diff --git a/Project.toml b/Project.toml index b4757af6..7eff12b6 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,22 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" +DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" + +[extensions] +BijectorsForwardDiffExt = "ForwardDiff" +BijectorsReverseDiffExt = "ReverseDiff" +BijectorsTrackerExt = "Tracker" +BijectorsZygoteExt = "Zygote" +BijectorsLazyArraysExt = "LazyArrays" +BijectorsDistributionsADExt = "DistributionsAD" + [compat] ArgCheck = "1, 2" ChainRulesCore = "0.10.11, 1" diff --git a/ext/BijectorsDistributionsADExt.jl b/ext/BijectorsDistributionsADExt.jl new file mode 100644 index 00000000..c34284a1 --- /dev/null +++ b/ext/BijectorsDistributionsADExt.jl @@ -0,0 +1,118 @@ +module BijectorsDistributionsADExt + +if isdefined(Base, :get_extension) + using Bijectors + using Bijectors: LinearAlgebra + using Bijectors.Distributions: AbstractMvLogNormal + using DistributionsAD: + TuringDirichlet, + TuringWishart, + TuringInverseWishart, + FillVectorOfUnivariate, + FillMatrixOfUnivariate, + MatrixOfUnivariate, + FillVectorOfMultivariate, + VectorOfMultivariate, + TuringScalMvNormal, + TuringDiagMvNormal, + TuringDenseMvNormal +else + using ..Bijectors + using ..Bijectors: LinearAlgebra + using ..Bijectors.Distributions: AbstractMvLogNormal + using ..DistributionsAD: + TuringDirichlet, + TuringWishart, + TuringInverseWishart, + FillVectorOfUnivariate, + FillMatrixOfUnivariate, + MatrixOfUnivariate, + FillVectorOfMultivariate, + VectorOfMultivariate, + TuringScalMvNormal, + TuringDiagMvNormal, + TuringDenseMvNormal +end + +# Bijectors + +Bijectors.bijector(::TuringDirichlet) = Bijectors.SimplexBijector() +Bijectors.bijector(::TuringWishart) = Bijectors.PDBijector() +Bijectors.bijector(::TuringInverseWishart) = Bijectors.PDBijector() +Bijectors.bijector(::TuringScalMvNormal) = identity +Bijectors.bijector(::TuringDiagMvNormal) = identity +Bijectors.bijector(::TuringDenseMvNormal) = identity + +function Bijectors.bijector(d::FillVectorOfUnivariate{Continuous}) + return elementwise(Bijectors.bijector(d.v.value)) +end +function Bijectors.bijector(d::FillMatrixOfUnivariate{Continuous}) + return elementwise(Bijectors.bijector(d.dists.value)) +end +Bijectors.bijector(d::MatrixOfUnivariate{Discrete}) = identity +function Bijectors.bijector(d::MatrixOfUnivariate{Continuous}) + return TruncatedBijectors.Bijector(_minmax(d.dists)...) +end +Bijectors.bijector(d::VectorOfMultivariate{Discrete}) = identity +for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) + @eval begin + Bijectors.bijector(d::$T{Continuous,<:MvNormal}) = identity + Bijectors.bijector(d::$T{Continuous,<:TuringScalMvNormal}) = identity + Bijectors.bijector(d::$T{Continuous,<:TuringDiagMvNormal}) = identity + Bijectors.bijector(d::$T{Continuous,<:TuringDenseMvNormal}) = identity + Bijectors.bijector(d::$T{Continuous,<:MvNormalCanon}) = identity + Bijectors.bijector(d::$T{Continuous,<:AbstractMvLogNormal}) = Log() + function Bijectors.bijector(d::$T{Continuous,<:SimplexDistribution}) + return Bijectors.SimplexBijector() + end + function Bijectors.bijector(d::$T{Continuous,<:TuringDirichlet}) + return Bijectors.SimplexBijector() + end + end +end +function Bijectors.bijector(d::FillVectorOfMultivariate{Continuous}) + return Bijectors.columnwise(Bijectors.bijector(d.dists.value)) +end + +Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:Dirichlet}) = true +Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = true +Bijectors.isdirichlet(::TuringDirichlet) = true + +function Bijectors.link( + d::TuringDirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) +) where {proj} + return Bijectors.SimplexBijector{proj}()(x) +end + +function Bijectors.link_jacobian( + d::TuringDirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true) +) where {proj} + return jacobian(Bijectors.SimplexBijector{proj}(), x) +end + +function Bijectors.invlink( + d::TuringDirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) +) where {proj} + return inverse(Bijectors.SimplexBijector{proj}())(y) +end +function Bijectors.invlink_jacobian( + d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true) +) where {proj} + return jacobian(inverse(Bijectors.SimplexBijector{proj}()), y) +end + +Bijectors.ispd(::TuringWishart) = true +Bijectors.ispd(::TuringInverseWishart) = true +function Bijectors.getlogp(d::TuringWishart, Xcf, X) + return ( + (d.df - (size(d, 1) + 1)) * LinearAlgebra.logdet(Xcf) - LinearAlgebra.tr(d.chol \ X) + ) / 2 + d.logc0 +end +function Bijectors.getlogp(d::TuringInverseWishart, Xcf, X) + Ψ = d.S + return -( + (d.df + size(d, 1) + 1) * LinearAlgebra.logdet(Xcf) + LinearAlgebra.tr(Xcf \ Ψ) + ) / 2 + d.logc0 +end + +end diff --git a/src/compat/forwarddiff.jl b/ext/BijectorsForwardDiffExt.jl similarity index 66% rename from src/compat/forwarddiff.jl rename to ext/BijectorsForwardDiffExt.jl index 1b51bb0b..29db3028 100644 --- a/src/compat/forwarddiff.jl +++ b/ext/BijectorsForwardDiffExt.jl @@ -1,11 +1,19 @@ -import .ForwardDiff +module BijectorsForwardDiffExt -_eps(::Type{<:ForwardDiff.Dual{<:Any,Real}}) = _eps(Real) -_eps(::Type{<:ForwardDiff.Dual{<:Any,<:Integer}}) = _eps(Real) +if isdefined(Base, :get_extension) + using Bijectors: Bijectors, find_alpha + using ForwardDiff: ForwardDiff +else + using ..Bijectors: Bijectors, find_alpha + using ..ForwardDiff: ForwardDiff +end + +Bijectors._eps(::Type{<:ForwardDiff.Dual{<:Any,Real}}) = Bijectors._eps(Real) +Bijectors._eps(::Type{<:ForwardDiff.Dual{<:Any,<:Integer}}) = Bijectors._eps(Real) # Define forward-mode rule for ForwardDiff and don't trust support for ForwardDiff in Roots # https://github.com/JuliaMath/Roots.jl/issues/314 -function find_alpha( +function Bijectors.find_alpha( wt_y::ForwardDiff.Dual{T,<:Real}, wt_u_hat::ForwardDiff.Dual{T,<:Real}, b::ForwardDiff.Dual{T,<:Real}, @@ -25,3 +33,5 @@ function find_alpha( return ForwardDiff.Dual{T}(Ω, ∂Ω) end + +end diff --git a/ext/BijectorsLazyArraysExt.jl b/ext/BijectorsLazyArraysExt.jl new file mode 100644 index 00000000..fa060470 --- /dev/null +++ b/ext/BijectorsLazyArraysExt.jl @@ -0,0 +1,21 @@ +module BijectorsLazyArraysExt + +if isdefined(Base, :get_extension) + import Bijectors: maporbroadcast + using LazyArrays: LazyArrays +else + import ..Bijectors: maporbroadcast + using ..LazyArrays: LazyArrays +end + +function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) + return copy(f.(x1, x...)) +end +function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...) + return copy(f.(x1, x2, x...)) +end +function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...) + return copy(f.(x1, x2, x3, x...)) +end + +end diff --git a/src/compat/reversediff.jl b/ext/BijectorsReverseDiffExt.jl similarity index 79% rename from src/compat/reversediff.jl rename to ext/BijectorsReverseDiffExt.jl index 7e95e69c..ef0cfebd 100644 --- a/src/compat/reversediff.jl +++ b/ext/BijectorsReverseDiffExt.jl @@ -1,47 +1,86 @@ -module ReverseDiffCompat - -using ..ReverseDiff: - ReverseDiff, - @grad, - value, - track, - TrackedReal, - TrackedVector, - TrackedMatrix, - @grad_from_chainrules -using Requires, LinearAlgebra - -using ..Bijectors: - Elementwise, - SimplexBijector, - maphcat, - simplex_link_jacobian, - simplex_invlink_jacobian, - simplex_logabsdetjac_gradient, - Inverse -import ..Bijectors: - _eps, - logabsdetjac, - _logabsdetjac_scale, - _simplex_bijector, - _simplex_inv_bijector, - replace_diag, - jacobian, - pd_from_lower, - pd_from_upper, - lower_triangular, - upper_triangular, - _inv_link_chol_lkj, - _link_chol_lkj, - _transform_ordered, - _transform_inverse_ordered, - find_alpha, - cholesky_factor - -using ChainRulesCore: ChainRulesCore - -using Compat: eachcol -using Distributions: LocationScale +module BijectorsReverseDiffExt + +if isdefined(Base, :get_extension) + using ReverseDiff: + ReverseDiff, + @grad, + value, + track, + TrackedReal, + TrackedVector, + TrackedMatrix, + @grad_from_chainrules + + using Bijectors: + ChainRulesCore, + Elementwise, + SimplexBijector, + maphcat, + simplex_link_jacobian, + simplex_invlink_jacobian, + simplex_logabsdetjac_gradient, + Inverse + import Bijectors: + _eps, + logabsdetjac, + _logabsdetjac_scale, + _simplex_bijector, + _simplex_inv_bijector, + replace_diag, + jacobian, + _inv_link_chol_lkj, + _link_chol_lkj, + _transform_ordered, + _transform_inverse_ordered, + find_alpha, + pd_from_lower, + lower_triangular, + upper_triangular + + using Bijectors.LinearAlgebra + using Bijectors.Compat: eachcol + using Bijectors.Distributions: LocationScale +else + using ..ReverseDiff: + ReverseDiff, + @grad, + value, + track, + TrackedReal, + TrackedVector, + TrackedMatrix, + @grad_from_chainrules + + using ..Bijectors: + ChainRulesCore, + Elementwise, + SimplexBijector, + maphcat, + simplex_link_jacobian, + simplex_invlink_jacobian, + simplex_logabsdetjac_gradient, + Inverse + import ..Bijectors: + _eps, + logabsdetjac, + _logabsdetjac_scale, + _simplex_bijector, + _simplex_inv_bijector, + replace_diag, + jacobian, + _inv_link_chol_lkj, + _link_chol_lkj, + _transform_ordered, + _transform_inverse_ordered, + find_alpha, + pd_from_lower, + lower_triangular, + upper_triangular + + using ..Bijectors.LinearAlgebra + using ..Bijectors.Compat: eachcol + using ..Bijectors.Distributions: LocationScale +end _eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T) function Base.minimum(d::LocationScale{<:TrackedReal}) diff --git a/src/compat/tracker.jl b/ext/BijectorsTrackerExt.jl similarity index 93% rename from src/compat/tracker.jl rename to ext/BijectorsTrackerExt.jl index 72925fce..bc12c46d 100644 --- a/src/compat/tracker.jl +++ b/ext/BijectorsTrackerExt.jl @@ -1,26 +1,58 @@ -module TrackerCompat - -using ..Tracker: - Tracker, - TrackedReal, - TrackedVector, - TrackedMatrix, - TrackedArray, - TrackedVecOrMat, - @grad, - track, - data, - param - -import ..Bijectors -using ..Bijectors: Elementwise, SimplexBijector, Inverse, Stacked, _triu1_dim_from_length - -using ChainRulesCore: ChainRulesCore -using LogExpFunctions: LogExpFunctions - -using Compat: eachcol -using LinearAlgebra -using Distributions: LocationScale +module BijectorsTrackerExt + +if isdefined(Base, :get_extension) + using Tracker: + Tracker, + TrackedReal, + TrackedVector, + TrackedMatrix, + TrackedArray, + TrackedVecOrMat, + @grad, + track, + data, + param + + using Bijectors: + Elementwise, + SimplexBijector, + Inverse, + Stacked, + Bijectors, + ChainRulesCore, + LogExpFunctions, + _triu1_dim_from_length + + using Bijectors.LinearAlgebra + using Bijectors.Compat: eachcol + using Bijectors.Distributions: LocationScale +else + using ..Tracker: + Tracker, + TrackedReal, + TrackedVector, + TrackedMatrix, + TrackedArray, + TrackedVecOrMat, + @grad, + track, + data, + param + + using Bijectors: + Elementwise, + SimplexBijector, + Inverse, + Stacked, + Bijectors, + ChainRulesCore, + LogExpFunctions, + _triu1_dim_from_length + + using ..Bijectors.LinearAlgebra + using ..Bijectors.Compat: eachcol + using ..Bijectors.Distributions: LocationScale +end Bijectors.maporbroadcast(f, x::TrackedArray...) = f.(x...) function Bijectors.maporbroadcast( @@ -283,7 +315,6 @@ function vectorof(::Type{TrackedReal{T}}) where {T<:Real} return TrackedArray{T,1,Vector{T}} end -(b::Elementwise{typeof(exp)})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) (b::Elementwise{typeof(exp)})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) (b::Elementwise{typeof(exp)})(x::TrackedMatrix) = exp.(x)::matrixof(float(eltype(x))) diff --git a/src/compat/zygote.jl b/ext/BijectorsZygoteExt.jl similarity index 72% rename from src/compat/zygote.jl rename to ext/BijectorsZygoteExt.jl index f0f23538..fd59706b 100644 --- a/src/compat/zygote.jl +++ b/ext/BijectorsZygoteExt.jl @@ -1,6 +1,78 @@ -using .Zygote: Zygote, @adjoint, pullback +module BijectorsZygoteExt -using Compat: eachcol +if isdefined(Base, :get_extension) + using Zygote: Zygote, @adjoint, pullback + using Bijectors: + Elementwise, + SimplexBijector, + simplex_link_jacobian, + simplex_invlink_jacobian, + simplex_logabsdetjac_gradient, + Inverse, + maphcat, + IrrationalConstants, + Distributions, + logabsdetjac, + _logabsdetjac_scale, + _simplex_bijector, + _simplex_inv_bijector, + replace_diag, + jacobian, + _inv_link_chol_lkj, + _link_chol_lkj, + _transform_ordered, + _transform_inverse_ordered, + find_alpha, + pd_logpdf_with_trans, + istraining, + mapvcat, + eachcolmaphcat, + sumeachcol, + pd_link, + pd_from_lower, + lower_triangular, + upper_triangular + + using Bijectors.LinearAlgebra + using Bijectors.Compat: eachcol + using Bijectors.Distributions: LocationScale +else + using ..Zygote: Zygote, @adjoint, pullback + using ..Bijectors: + Elementwise, + SimplexBijector, + simplex_link_jacobian, + simplex_invlink_jacobian, + simplex_logabsdetjac_gradient, + Inverse, + maphcat, + IrrationalConstants, + Distributions, + logabsdetjac, + _logabsdetjac_scale, + _simplex_bijector, + _simplex_inv_bijector, + replace_diag, + jacobian, + _inv_link_chol_lkj, + _link_chol_lkj, + _transform_ordered, + _transform_inverse_ordered, + find_alpha, + pd_logpdf_with_trans, + istraining, + mapvcat, + eachcolmaphcat, + sumeachcol, + pd_link, + pd_from_lower, + lower_triangular, + upper_triangular + + using ..Bijectors.LinearAlgebra + using ..Bijectors.Compat: eachcol + using ..Bijectors.Distributions: LocationScale +end @adjoint istraining() = true, _ -> nothing @@ -131,8 +203,8 @@ end end # LocationScale fix - -@adjoint function minimum(d::LocationScale) +# TODO: Remove this. +@adjoint function Base.minimum(d::Distributions.LocationScale) function _minimum(d) m = minimum(d.ρ) if isfinite(m) @@ -143,7 +215,7 @@ end end return pullback(_minimum, d) end -@adjoint function maximum(d::LocationScale) +@adjoint function Base.maximum(d::LocationScale) function _maximum(d) m = maximum(d.ρ) if isfinite(m) @@ -170,3 +242,4 @@ end return replace_diag(log, Y) end end +end diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 605fe002..46b31fb3 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -28,7 +28,7 @@ module Bijectors > SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. =# -using Reexport, Requires +using Reexport @reexport using Distributions using LinearAlgebra using MappedArrays @@ -45,6 +45,7 @@ using Functors: Functors using IrrationalConstants: IrrationalConstants using LogExpFunctions: LogExpFunctions using Roots: Roots +using Compat: Compat export TransformDistribution, PositiveDistribution, @@ -279,29 +280,31 @@ maporbroadcast(f, x::AbstractArray{<:Any,N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) # optional dependencies +if !isdefined(Base, :get_extension) + using Requires +end + function __init__() - @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin - function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) - return copy(f.(x1, x...)) - end - function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...) - return copy(f.(x1, x2, x...)) - end - function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...) - return copy(f.(x1, x2, x3, x...)) - end + @static if !isdefined(Base, :get_extension) + @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" include( + "../ext/BijectorsLazyArraysExt.jl" + ) + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( + "../ext/BijectorsForwardDiffExt.jl" + ) + @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include( + "../ext/BijectorsTrackerExt.jl" + ) + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include( + "../ext/BijectorsZygoteExt.jl" + ) + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( + "../ext/BijectorsReverseDiffExt.jl" + ) + @require DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" include( + "../ext/BijectorsDistributionsADExt.jl" + ) end - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( - "compat/forwarddiff.jl" - ) - @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl") - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl") - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( - "compat/reversediff.jl" - ) - @require DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" include( - "compat/distributionsad.jl" - ) end end # module diff --git a/src/compat/distributionsad.jl b/src/compat/distributionsad.jl deleted file mode 100644 index af5397f9..00000000 --- a/src/compat/distributionsad.jl +++ /dev/null @@ -1,78 +0,0 @@ -using .DistributionsAD: - TuringDirichlet, - TuringWishart, - TuringInverseWishart, - FillVectorOfUnivariate, - FillMatrixOfUnivariate, - MatrixOfUnivariate, - FillVectorOfMultivariate, - VectorOfMultivariate, - TuringScalMvNormal, - TuringDiagMvNormal, - TuringDenseMvNormal -using Distributions: AbstractMvLogNormal - -# Bijectors - -bijector(::TuringDirichlet) = SimplexBijector() -bijector(::TuringWishart) = PDBijector() -bijector(::TuringInverseWishart) = PDBijector() -bijector(::TuringScalMvNormal) = identity -bijector(::TuringDiagMvNormal) = identity -bijector(::TuringDenseMvNormal) = identity - -bijector(d::FillVectorOfUnivariate{Continuous}) = elementwise(bijector(d.v.value)) -bijector(d::FillMatrixOfUnivariate{Continuous}) = elementwise(bijector(d.dists.value)) -bijector(d::MatrixOfUnivariate{Discrete}) = identity -bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists)...) -bijector(d::VectorOfMultivariate{Discrete}) = identity -for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) - @eval begin - bijector(d::$T{Continuous,<:MvNormal}) = identity - bijector(d::$T{Continuous,<:TuringScalMvNormal}) = identity - bijector(d::$T{Continuous,<:TuringDiagMvNormal}) = identity - bijector(d::$T{Continuous,<:TuringDenseMvNormal}) = identity - bijector(d::$T{Continuous,<:MvNormalCanon}) = identity - bijector(d::$T{Continuous,<:AbstractMvLogNormal}) = Log() - bijector(d::$T{Continuous,<:SimplexDistribution}) = SimplexBijector() - bijector(d::$T{Continuous,<:TuringDirichlet}) = SimplexBijector() - end -end -bijector(d::FillVectorOfMultivariate{Continuous}) = columnwise(bijector(d.dists.value)) - -isdirichlet(::VectorOfMultivariate{Continuous,<:Dirichlet}) = true -isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = true -isdirichlet(::TuringDirichlet) = true - -function link( - d::TuringDirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return SimplexBijector{proj}()(x) -end - -function link_jacobian( - d::TuringDirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return jacobian(SimplexBijector{proj}(), x) -end - -function invlink( - d::TuringDirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return inverse(SimplexBijector{proj}())(y) -end -function invlink_jacobian( - d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return jacobian(inverse(SimplexBijector{proj}()), y) -end - -ispd(::TuringWishart) = true -ispd(::TuringInverseWishart) = true -function getlogp(d::TuringWishart, Xcf, X) - return ((d.df - (size(d, 1) + 1)) * logdet(Xcf) - tr(d.chol \ X)) / 2 + d.logc0 -end -function getlogp(d::TuringInverseWishart, Xcf, X) - Ψ = d.S - return -((d.df + size(d, 1) + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) / 2 + d.logc0 -end diff --git a/test/Project.toml b/test/Project.toml index 51166ab9..359e5869 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -24,6 +25,7 @@ FiniteDifferences = "0.11, 0.12" ForwardDiff = "0.10.12" Functors = "0.1, 0.2, 0.3, 0.4" InverseFunctions = "0.1" +LazyArrays = "1" LogExpFunctions = "0.3.1" ReverseDiff = "1.4.2" Tracker = "0.2.11" diff --git a/test/runtests.jl b/test/runtests.jl index cb8d9455..a819dfd2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,7 @@ using Bijectors: using ChangesOfVariables: ChangesOfVariables using InverseFunctions: InverseFunctions +using LazyArrays: LazyArrays const GROUP = get(ENV, "GROUP", "All")