-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* moved compats for forwarddiff, reversediff, and tracker to extensions * extension for Zygote * renamd the extensions * added LazyArrays extension * moved DistributionsAD to ext * extensions should now be working * fix imports for DistributionsAD extension * fixed imports in ReverseDiff extension * fixed imports in Tracker extension * fixed imports in Zygote extension * formatting * formmatted extensions * now importing Compat as its needed by extensions * think i fixed you extension * another attempt * added missing LinearAlgebra qualification for DistributionsAD ext * Update ext/BijectorsDistributionsADExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update ext/BijectorsDistributionsADExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added back some changes that accidentally was dropped in the merging with master * Update ext/BijectorsDistributionsADExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added missing qualification * only load Requires if needed * use @static * load LAzyArrays in tests to make sure extension is working * fixed ReverseDiff * fixed imports to Zygote * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added missing method in BijectorsTrackerExt * removed accidentally included changes in BijectorsZygoteExt --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Loading branch information
1 parent
3a0b7e3
commit 451a2fc
Showing
11 changed files
with
413 additions
and
177 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
module BijectorsDistributionsADExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using Bijectors | ||
using Bijectors: LinearAlgebra | ||
using Bijectors.Distributions: AbstractMvLogNormal | ||
using DistributionsAD: | ||
TuringDirichlet, | ||
TuringWishart, | ||
TuringInverseWishart, | ||
FillVectorOfUnivariate, | ||
FillMatrixOfUnivariate, | ||
MatrixOfUnivariate, | ||
FillVectorOfMultivariate, | ||
VectorOfMultivariate, | ||
TuringScalMvNormal, | ||
TuringDiagMvNormal, | ||
TuringDenseMvNormal | ||
else | ||
using ..Bijectors | ||
using ..Bijectors: LinearAlgebra | ||
using ..Bijectors.Distributions: AbstractMvLogNormal | ||
using ..DistributionsAD: | ||
TuringDirichlet, | ||
TuringWishart, | ||
TuringInverseWishart, | ||
FillVectorOfUnivariate, | ||
FillMatrixOfUnivariate, | ||
MatrixOfUnivariate, | ||
FillVectorOfMultivariate, | ||
VectorOfMultivariate, | ||
TuringScalMvNormal, | ||
TuringDiagMvNormal, | ||
TuringDenseMvNormal | ||
end | ||
|
||
# Bijectors | ||
|
||
Bijectors.bijector(::TuringDirichlet) = Bijectors.SimplexBijector() | ||
Bijectors.bijector(::TuringWishart) = Bijectors.PDBijector() | ||
Bijectors.bijector(::TuringInverseWishart) = Bijectors.PDBijector() | ||
Bijectors.bijector(::TuringScalMvNormal) = identity | ||
Bijectors.bijector(::TuringDiagMvNormal) = identity | ||
Bijectors.bijector(::TuringDenseMvNormal) = identity | ||
|
||
function Bijectors.bijector(d::FillVectorOfUnivariate{Continuous}) | ||
return elementwise(Bijectors.bijector(d.v.value)) | ||
end | ||
function Bijectors.bijector(d::FillMatrixOfUnivariate{Continuous}) | ||
return elementwise(Bijectors.bijector(d.dists.value)) | ||
end | ||
Bijectors.bijector(d::MatrixOfUnivariate{Discrete}) = identity | ||
function Bijectors.bijector(d::MatrixOfUnivariate{Continuous}) | ||
return TruncatedBijectors.Bijector(_minmax(d.dists)...) | ||
end | ||
Bijectors.bijector(d::VectorOfMultivariate{Discrete}) = identity | ||
for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) | ||
@eval begin | ||
Bijectors.bijector(d::$T{Continuous,<:MvNormal}) = identity | ||
Bijectors.bijector(d::$T{Continuous,<:TuringScalMvNormal}) = identity | ||
Bijectors.bijector(d::$T{Continuous,<:TuringDiagMvNormal}) = identity | ||
Bijectors.bijector(d::$T{Continuous,<:TuringDenseMvNormal}) = identity | ||
Bijectors.bijector(d::$T{Continuous,<:MvNormalCanon}) = identity | ||
Bijectors.bijector(d::$T{Continuous,<:AbstractMvLogNormal}) = Log() | ||
function Bijectors.bijector(d::$T{Continuous,<:SimplexDistribution}) | ||
return Bijectors.SimplexBijector() | ||
end | ||
function Bijectors.bijector(d::$T{Continuous,<:TuringDirichlet}) | ||
return Bijectors.SimplexBijector() | ||
end | ||
end | ||
end | ||
function Bijectors.bijector(d::FillVectorOfMultivariate{Continuous}) | ||
return Bijectors.columnwise(Bijectors.bijector(d.dists.value)) | ||
end | ||
|
||
Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:Dirichlet}) = true | ||
Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = true | ||
Bijectors.isdirichlet(::TuringDirichlet) = true | ||
|
||
function Bijectors.link( | ||
d::TuringDirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) | ||
) where {proj} | ||
return Bijectors.SimplexBijector{proj}()(x) | ||
end | ||
|
||
function Bijectors.link_jacobian( | ||
d::TuringDirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true) | ||
) where {proj} | ||
return jacobian(Bijectors.SimplexBijector{proj}(), x) | ||
end | ||
|
||
function Bijectors.invlink( | ||
d::TuringDirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) | ||
) where {proj} | ||
return inverse(Bijectors.SimplexBijector{proj}())(y) | ||
end | ||
function Bijectors.invlink_jacobian( | ||
d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true) | ||
) where {proj} | ||
return jacobian(inverse(Bijectors.SimplexBijector{proj}()), y) | ||
end | ||
|
||
Bijectors.ispd(::TuringWishart) = true | ||
Bijectors.ispd(::TuringInverseWishart) = true | ||
function Bijectors.getlogp(d::TuringWishart, Xcf, X) | ||
return ( | ||
(d.df - (size(d, 1) + 1)) * LinearAlgebra.logdet(Xcf) - LinearAlgebra.tr(d.chol \ X) | ||
) / 2 + d.logc0 | ||
end | ||
function Bijectors.getlogp(d::TuringInverseWishart, Xcf, X) | ||
Ψ = d.S | ||
return -( | ||
(d.df + size(d, 1) + 1) * LinearAlgebra.logdet(Xcf) + LinearAlgebra.tr(Xcf \ Ψ) | ||
) / 2 + d.logc0 | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
module BijectorsLazyArraysExt | ||
|
||
if isdefined(Base, :get_extension) | ||
import Bijectors: maporbroadcast | ||
using LazyArrays: LazyArrays | ||
else | ||
import ..Bijectors: maporbroadcast | ||
using ..LazyArrays: LazyArrays | ||
end | ||
|
||
function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) | ||
return copy(f.(x1, x...)) | ||
end | ||
function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...) | ||
return copy(f.(x1, x2, x...)) | ||
end | ||
function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...) | ||
return copy(f.(x1, x2, x3, x...)) | ||
end | ||
|
||
end |
127 changes: 83 additions & 44 deletions
127
src/compat/reversediff.jl → ext/BijectorsReverseDiffExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
451a2fc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
451a2fc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Error while trying to register: "Tag with name
v0.12.6
already exists and points to a different commit"