Skip to content

Commit

Permalink
Implementation of VecCorrBijector (#246)
Browse files Browse the repository at this point in the history
* initial work on VecCorrBijector

* added some tests for CorrBijector, and fixed implementation for VecCorrBijector

* improved tests and are now using integer sqrt and division

* moved things around a bit

* added chainrule for ReverseDiff

* some fixes for AD

* added some TODOs

* Update src/bijectors/corr.jl

* define bijectors for `LKJ` and `LKJCholesky`

* add `TransformedDistribution` constructor
for `LKJCholesky`

* define `logpdf` for `LKJ` & `LKJCholesky`

* define `rand` for `LKJ` & `LKJCholesky`

* add util to extract Cholesky factor

* TYPO: capitalize matrix

* add util to convert `Vector` index
to `Matrix` row index

* add `VecTriBijector`s for `LKJCholesky`

* TYPO: capitilize matrix

* add `LKJCholesky` link for `UpperTriangular`

* add `LKJCholesky` link for `LowerTriangular`

* TYPO: capitalize matrix

* add `LKJCholesky` inverse link to `UpperTriangular`

* rename `_logabsdetjac_chol_lkj`
to `_logabsdetjac_inv_corr`

* dispatch `_logabsdetjac_inv_corr` for `::Vector`

* add logabsdetjac for inverse link of `LKJCholesky`

* add tests for `VecTriBijector`s

* add `rrule` for LKJ(Cholesky) link function

* use `transpose` in link for `::LowerTriangular'

* add `Tracker` support for inverse link

* better utility function call

* use function barrier properly for type stability

* account for difference in support dimensions

* fix indexing in Jacobian of `VecCorrBijector`

* add `_logabsdetjac_dist` for `::LKJCholesky`

* replace function composition for proper barrier

* add util convert `Transpose -> Matrix` for type stability

* add `LKJCholesky` Jacobian+type tests

* fix `logabsdetjac` for inverse link

* use `Cholesky` constructor compatible with `v1.6`

* add empty line

* fix `rrule` for link function

* add link `rrule` test

* add `rrule` for inverse link

* remove TODO

* add inverse link `rrule` test

* Update src/bijectors/corr.jl

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* add link `rrule` for `LowerTriangular`

* add `LowerTriangular` chainrule test

* Update src/bijectors/corr.jl

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* remove unused util

* use `similar` instead of `zeros`

* update comments

* remove old comment

* minimize zero-setting operations in inverse link

* minimize zero-setting operations in `rrule`

* add parametric `Val` type to `VecCorrBijector`

* update `VecCorrBijector` tests

* use field value instead of `Val`-parametric type

* update tests with new `VecCorrBijector`

* `using VecCorrBijector` in test utils

* add `VecCorrBijector.mode` check

* update `VecCorrBijector` docstring

* specialise `Zygote@adjoint` for `AbstractMatrix`

* `ReverseDiff` opt-in to `ChainRules`

* empty lines format

* add AD test for inverse link

* include `VecCorrBijector` tests

* remove broken flag for `Tracker`

* add roundtrip AD tests for `VecCorrBijector`

* remove wrong `ReverseDiff.@grad` for `pd_from_upper`

* add corrected `rrule` for `pd_from_upper`

* update AD tests

* remove `Tracker` from broken

* update zero-filling in `Tracker` pullback

* fix `Zygote`

* merge lines - applying feedback suggestions

* `unthunk` in `pd_from_upper` rrule

* split structs into `VecCorrBijector` and `VecCholeskyBijector`

* remove old `Zygote` adjoints

* update tests

* fix `Union` in `@inferred` after splitting structs

* remove `Tracker` tests as support is dropped

* use `permutedims` instead of casting

* remove `Union` in `@inferred`

* initial work on VecCorrBijector

* added some tests for CorrBijector, and fixed implementation for VecCorrBijector

* improved tests and are now using integer sqrt and division

* moved things around a bit

* added chainrule for ReverseDiff

* some fixes for AD

* added some TODOs

* define bijectors for `LKJ` and `LKJCholesky`

* add `TransformedDistribution` constructor
for `LKJCholesky`

* define `logpdf` for `LKJ` & `LKJCholesky`

* define `rand` for `LKJ` & `LKJCholesky`

* add util to extract Cholesky factor

* TYPO: capitalize matrix

* add util to convert `Vector` index
to `Matrix` row index

* add `VecTriBijector`s for `LKJCholesky`

* TYPO: capitilize matrix

* add `LKJCholesky` link for `UpperTriangular`

* add `LKJCholesky` link for `LowerTriangular`

* TYPO: capitalize matrix

* add `LKJCholesky` inverse link to `UpperTriangular`

* rename `_logabsdetjac_chol_lkj`
to `_logabsdetjac_inv_corr`

* dispatch `_logabsdetjac_inv_corr` for `::Vector`

* add logabsdetjac for inverse link of `LKJCholesky`

* add tests for `VecTriBijector`s

* add `rrule` for LKJ(Cholesky) link function

* use `transpose` in link for `::LowerTriangular'

* add `Tracker` support for inverse link

* better utility function call

* use function barrier properly for type stability

* account for difference in support dimensions

* fix indexing in Jacobian of `VecCorrBijector`

* add `_logabsdetjac_dist` for `::LKJCholesky`

* replace function composition for proper barrier

* add util convert `Transpose -> Matrix` for type stability

* add `LKJCholesky` Jacobian+type tests

* fix `logabsdetjac` for inverse link

* use `Cholesky` constructor compatible with `v1.6`

* add empty line

* fix `rrule` for link function

* add link `rrule` test

* add `rrule` for inverse link

* remove TODO

* add inverse link `rrule` test

* Update src/bijectors/corr.jl

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* add link `rrule` for `LowerTriangular`

* add `LowerTriangular` chainrule test

* Update src/bijectors/corr.jl

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* remove unused util

* use `similar` instead of `zeros`

* update comments

* remove old comment

* minimize zero-setting operations in inverse link

* minimize zero-setting operations in `rrule`

* add parametric `Val` type to `VecCorrBijector`

* update `VecCorrBijector` tests

* use field value instead of `Val`-parametric type

* update tests with new `VecCorrBijector`

* `using VecCorrBijector` in test utils

* add `VecCorrBijector.mode` check

* update `VecCorrBijector` docstring

* specialise `Zygote@adjoint` for `AbstractMatrix`

* `ReverseDiff` opt-in to `ChainRules`

* empty lines format

* add AD test for inverse link

* include `VecCorrBijector` tests

* remove broken flag for `Tracker`

* add roundtrip AD tests for `VecCorrBijector`

* remove wrong `ReverseDiff.@grad` for `pd_from_upper`

* add corrected `rrule` for `pd_from_upper`

* update AD tests

* remove `Tracker` from broken

* update zero-filling in `Tracker` pullback

* fix `Zygote`

* merge lines - applying feedback suggestions

* `unthunk` in `pd_from_upper` rrule

* split structs into `VecCorrBijector` and `VecCholeskyBijector`

* remove old `Zygote` adjoints

* update tests

* fix `Union` in `@inferred` after splitting structs

* remove `Tracker` tests as support is dropped

* use `permutedims` instead of casting

* remove `Union` in `@inferred`

* wrap matrix in `Hermitian` before `cholesky`

* add hacky dispatch for `cholesky_factor`  and `ReverseDiff`

* import `cholesky_factor` in ReverseDiff module for hacky dispatch

* only use hacky `cholesky_factor` in versions before fix

* change `LKJCholesky` shape to avoid stochastic test failures

* remove old TODOs

* add explicit zero-filling in link for `CorrBijector`

---------

Co-authored-by: harisorgn <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
3 people authored Jun 12, 2023
1 parent 24fa396 commit 3a0b7e3
Show file tree
Hide file tree
Showing 14 changed files with 777 additions and 181 deletions.
5 changes: 4 additions & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ using Reexport, Requires
using LinearAlgebra
using MappedArrays
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular
using LinearAlgebra: AbstractTriangular, Hermitian

using InverseFunctions: InverseFunctions

Expand Down Expand Up @@ -145,6 +145,9 @@ function _logabsdetjac_dist(d::MatrixDistribution, x::AbstractVector{<:AbstractM
return logabsdetjac.((bijector(d),), x)
end

_logabsdetjac_dist(d::LKJCholesky, x::Cholesky) = logabsdetjac(bijector(d), x)
_logabsdetjac_dist(d::LKJCholesky, x::AbstractVector) = logabsdetjac.((bijector(d),), x)

function logpdf_with_trans(d::Distribution, x, transform::Bool)
if ispd(d)
return pd_logpdf_with_trans(d, x, transform)
Expand Down
Loading

0 comments on commit 3a0b7e3

Please sign in to comment.