Add support for distributions with monotonically increasing bijector (#297)

* add support for `ordered` when bijector is monotonically increasing

* bump patch version

* Update src/interface.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* formatting

* Update src/interface.jl

Co-authored-by: Seth Axen <[email protected]>

* added more impls of is_monotonically_increasing

* added `is_monotonically_increasing` for `Shift`

* reverted is_monotonically_increasing impl for Scale but added for
Shift, as originally intended

* added impl of `is_monotonically_decreasing` and corrected impls for compositions

* added monotonic impls for `Scale`

* added monotonic impls for `TruncatedBijector`

* `ordered` now also supports monotonically decreasing transformations

* added `inverse` impl for `SignFlip`

* fixed `output_size` for `SignFlip`

* formatting

* another test case

* updated a comment

* added some additional comments

* Apply suggestions from code review

* Update test/bijectors/ordered.jl

* added `OrderedDistribution` to address bugs in current `ordered`

* return `OrderedDistribution` from `ordered`

* move the `ordered` definition to be near `OrderedDistribution`

* initial work on adding tests

* added currently failing correctness tests

* fixed `rand` for `OrderedDistribution`

* more extensive correctness testing of `ordered`

* test ordered for higher dims

* Apply suggestions from code review

Co-authored-by: Seth Axen <[email protected]>

* don't use `InverseGamma` as target due to heavy tails

* Update src/bijectors/ordered.jl

Co-authored-by: Seth Axen <[email protected]>

* fixed syntax error

* fixed OrderedBijector + added some docs for it

* forgot to uncomment tests in previous commit + fixed them

* more test uncommented

* fixed failing tests for LKJ

* Update test/ad/chainrules.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/ad/chainrules.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/ad/chainrules.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* better initialization for ordered chains

* added the description of the un-normalized `ordered` issue

* fixed docstring of OrderedBijector

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Seth Axen <[email protected]>
3 people authored Jun 27, 2024
1 parent edcdf8c commit 026a07a
Showing 13 changed files with 373 additions and 33 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -9,6 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
 IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
@@ -45,8 +46,9 @@ ChainRulesCore = "0.10.11, 1"
 ChangesOfVariables = "0.1"
 Compat = "3.46, 4.2"
 Distributions = "0.25.33"
-ForwardDiff = "0.10"
 DistributionsAD = "0.6"
+DocStringExtensions = "0.9"
+ForwardDiff = "0.10"
 Functors = "0.1, 0.2, 0.3, 0.4"
 InverseFunctions = "0.1"
 IrrationalConstants = "0.1, 0.2"
1 change: 1 addition & 0 deletions src/Bijectors.jl
@@ -47,6 +47,7 @@ using IrrationalConstants: IrrationalConstants
 using LogExpFunctions: LogExpFunctions
 using Roots: Roots
 using Compat: Compat
+using DocStringExtensions: TYPEDFIELDS

 export TransformDistribution,
     PositiveDistribution,
3 changes: 3 additions & 0 deletions src/bijectors/exp_log.jl
@@ -7,3 +7,6 @@ logabsdetjac(b::Elementwise{typeof(exp)}, x) = sum(x)

 logabsdetjac(b::typeof(log), x::Real) = -log(x)
 logabsdetjac(b::Elementwise{typeof(log)}, x) = -sum(log, x)
+
+is_monotonically_increasing(::typeof(exp)) = true
+is_monotonically_increasing(::typeof(log)) = true
2 changes: 2 additions & 0 deletions src/bijectors/leaky_relu.jl
@@ -27,3 +27,5 @@ function with_logabsdet_jacobian(b::LeakyReLU, x::AbstractArray)
     J = mask .* b.α .+ (!).(mask)
     return J .* x, sum(log.(abs.(J)))
 end
+
+is_monotonically_increasing(::LeakyReLU) = true
2 changes: 2 additions & 0 deletions src/bijectors/logit.jl
@@ -28,3 +28,5 @@ logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b))
 function with_logabsdet_jacobian(b::Logit, x)
     return _logit.(x, b.a, b.b), sum(logit_logabsdetjac.(x, b.a, b.b))
 end
+
+is_monotonically_increasing(::Logit) = true
116 changes: 97 additions & 19 deletions src/bijectors/ordered.jl
@@ -1,32 +1,22 @@
+struct SignFlip <: Bijector end
+
+with_logabsdet_jacobian(::SignFlip, x) = -x, zero(eltype(x))
+inverse(::SignFlip) = SignFlip()
+output_size(::SignFlip, dim) = dim
+is_monotonically_increasing(::SignFlip) = false
+is_monotonically_decreasing(::SignFlip) = true
+
 """
     OrderedBijector()

-A bijector mapping ordered vectors in ℝᵈ to unordered vectors in ℝᵈ.
+A bijector mapping unordered vectors in ℝᵈ to ordered vectors in ℝᵈ.

 ## See also
 - [Stan's documentation](https://mc-stan.org/docs/2_27/reference-manual/ordered-vector.html)
   - Note that this transformation and its inverse are the _opposite_ of in this reference.
 """
 struct OrderedBijector <: Bijector end

-"""
-    ordered(d::Distribution)
-
-Return a `Distribution` whose support are ordered vectors, i.e., vectors with increasingly ordered elements.
-
-This transformation is currently only supported for otherwise unconstrained distributions.
-"""
-function ordered(d::ContinuousMultivariateDistribution)
-    if bijector(d) !== identity
-        throw(
-            ArgumentError(
-                "ordered transform is currently only supported for unconstrained distributions.",
-            ),
-        )
-    end
-    return transformed(d, OrderedBijector())
-end
-
 with_logabsdet_jacobian(b::OrderedBijector, x) = transform(b, x), logabsdetjac(b, x)

 transform(b::OrderedBijector, y::AbstractVecOrMat) = _transform_ordered(y)
@@ -88,3 +78,91 @@ end

 logabsdetjac(b::OrderedBijector, x::AbstractVector) = sum(@view(x[2:end]))
 logabsdetjac(b::OrderedBijector, x::AbstractMatrix) = vec(sum(@view(x[2:end, :]); dims=1))
+
+# Need a custom distribution type to handle this properly.
+"""
+    OrderedDistribution
+
+Wraps a distribution to restrict its support to the subspace of ordered vectors.
+
+# Fields
+$(TYPEDFIELDS)
+"""
+struct OrderedDistribution{D<:ContinuousMultivariateDistribution,B} <:
+       ContinuousMultivariateDistribution
+    "distribution transformed to have ordered support"
+    dist::D
+    "transformation from constrained space to ordered unconstrained space"
+    transform::B
+end
+
+"""
+    ordered(d::Distribution)
+
+Return a `Distribution` whose support are ordered vectors, i.e., vectors with increasingly ordered elements.
+
+Specifically, `d` is restricted to the subspace of its domain containing only ordered elements.
+
+!!! warning
+    `rand` is implemented using rejection sampling, which can be slow for high-dimensional distributions.
+    In such cases, consider using MCMC methods to sample from the distribution instead.
+
+!!! warning
+    The resulting ordered distribution is un-normalized, which can cause issues in some contexts, e.g. in
+    hierarchical models where the parameters of the ordered distribution are themselves sampled.
+    See the notes below for a more detailed discussion.
+
+## Notes on `ordered` being un-normalized
+
+The resulting ordered distribution is un-normalized. This is not a problem if used in a context where the
+normalizing factor is irrelevant, but if the value of the normalizing factor impacts the resulting computation,
+the results may be inaccurate.
+
+For example, if the distribution is used in sampling a posterior distribution with MCMC and the parameters
+of the ordered distribution are themselves sampled, then the normalizing factor would in general be needed
+for accurate sampling, and `ordered` should not be used. However, if the parameters are fixed, then since
+MCMC does not require distributions be normalized, `ordered` may be used without problems.
+
+A common case is where the distribution being ordered is a joint distribution of `n` identical univariate
+distributions. In this case the normalization factor works out to be the constant `n!`, and `ordered` can
+again be used without problems even if the parameters of the univariate distribution are sampled.
+"""
+function ordered(d::ContinuousMultivariateDistribution)
+    # We're good if the map from unconstrained (in which we apply the ordered bijector)
+    # to constrained is monotonically increasing, i.e. order-preserving. In that case,
+    # we can form the ordered transformation as `binv ∘ OrderedBijector() ∘ b`.
+    # Similarly, if we're working with monotonically decreasing maps, we can do the same
+    # but with the addition of a sign flip before and after the ordered bijector.
+    b = bijector(d)
+    binv = inverse(b)
+    ordered_b = if is_monotonically_decreasing(binv)
+        SignFlip() ∘ inverse(OrderedBijector()) ∘ SignFlip() ∘ b
+    elseif is_monotonically_increasing(binv)
+        inverse(OrderedBijector()) ∘ b
+    else
+        throw(ArgumentError("ordered transform is currently not supported for $d."))
+    end
+
+    return OrderedDistribution(d, ordered_b)
+end
+
+bijector(d::OrderedDistribution) = d.transform
+
+Base.eltype(::Type{<:OrderedDistribution{D}}) where {D} = eltype(D)
+Base.eltype(d::OrderedDistribution) = eltype(d.dist)
+function Distributions._logpdf(d::OrderedDistribution, x::AbstractVector{<:Real})
+    lp = Distributions.logpdf(d.dist, x)
+    issorted(x) && return lp
+    return oftype(lp, -Inf)
+end
+Base.length(d::OrderedDistribution) = length(d.dist)
+
+function Distributions._rand!(
+    rng::AbstractRNG, d::OrderedDistribution, x::AbstractVector{<:Real}
+)
+    # Rejection sampling.
+    while true
+        Distributions.rand!(rng, d.dist, x)
+        issorted(x) && return x
+    end
+end
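
The rejection sampler and the `n!` remark in the `ordered` docstring can be checked with a small standalone sketch. It uses neither Bijectors nor Distributions: `sample_ordered!` and `acceptance_rate` are hypothetical helper names, and i.i.d. standard normals stand in for the wrapped distribution. For i.i.d. components every ordering is equally likely, so a draw is sorted with probability `1/n!`, which is both the acceptance rate of the sampler and the missing normalizing constant.

```julia
using Random

# Hypothetical helper: rejection-sample an ordered vector by redrawing
# i.i.d. standard normals until the draw happens to be sorted.
function sample_ordered!(rng::AbstractRNG, x::AbstractVector{<:Real})
    while true
        randn!(rng, x)
        issorted(x) && return x
    end
end

# Hypothetical helper: Monte Carlo estimate of the probability that a
# single unconstrained draw of length `n` is already sorted.
function acceptance_rate(rng::AbstractRNG, n::Integer; trials::Integer=100_000)
    x = zeros(n)
    hits = count(_ -> issorted(randn!(rng, x)), 1:trials)
    return hits / trials
end

rng = MersenneTwister(1)
x = sample_ordered!(rng, zeros(3))
rate = acceptance_rate(rng, 3)   # should be close to 1 / factorial(3)
```

Since the acceptance rate falls off like `1/n!`, this also illustrates why the docstring recommends MCMC over `rand` in higher dimensions.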
3 changes: 3 additions & 0 deletions src/bijectors/scale.jl
@@ -34,3 +34,6 @@ _logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(log
 # Matrix: single input.
 _logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Val{1}) = logabsdet(a)[1]
 _logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix, ::Val{2}) = logabsdet(a)[1]
+
+is_monotonically_increasing(a::Scale) = all(>(0), a.a)
+is_monotonically_decreasing(a::Scale) = all(<(0), a.a)
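
Whether an elementwise scale is increasing or decreasing depends only on the sign of its coefficients. The sketch below is plain Julia and independent of Bijectors; `scale_is_increasing` and `scale_is_decreasing` are hypothetical helper names. It also spells out the argument order of `Base.Fix1` and `Base.Fix2`, which is easy to mix up when writing such predicates.

```julia
# Base.Fix1(f, a) is y -> f(a, y); Base.Fix2(f, a) is y -> f(y, a).
# The one-argument form >(0) is shorthand for Base.Fix2(>, 0).
below_zero = Base.Fix1(>, 0)   # y -> 0 > y
above_zero = Base.Fix2(>, 0)   # y -> y > 0

# Hypothetical helpers: x -> a .* x preserves order when every coefficient
# is positive and reverses it when every coefficient is negative.
scale_is_increasing(a) = all(>(0), a)
scale_is_decreasing(a) = all(<(0), a)

a = [2.0, 0.5]
x, y = 1.0, 3.0                # x < y
preserved = all(a .* x .< a .* y)
```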
3 changes: 3 additions & 0 deletions src/bijectors/shift.jl
@@ -22,3 +22,6 @@ _logabsdetjac_shift(a, x) = zero(eltype(x))
 _logabsdetjac_shift_array_batch(a, x) = zeros(eltype(x), size(x, ndims(x)))

 with_logabsdet_jacobian(b::Shift, x) = transform(b, x), logabsdetjac(b, x)
+
+is_monotonically_increasing(::Shift) = true
+is_monotonically_decreasing(::Shift) = false
33 changes: 33 additions & 0 deletions src/bijectors/truncated.jl
@@ -67,3 +67,36 @@ function truncated_logabsdetjac(x, a, b)
 end

 with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x)
+
+# It's only monotonically decreasing if it's only upper-bounded.
+# In the multivariate case, we can only say something reasonable if entries are monotonic.
+function is_monotonically_increasing(b::TruncatedBijector)
+    lowerbounded, upperbounded = all(isfinite, b.lb), all(isfinite, b.ub)
+    return if lowerbounded
+        true
+    elseif upperbounded
+        # => decreasing
+        false
+    elseif all(!isfinite, b.lb) && all(!isfinite, b.ub)
+        # => all are unbounded so we have the identity
+        true
+    else
+        # => some are unbounded and some are bounded
+        false
+    end
+end
+function is_monotonically_decreasing(b::TruncatedBijector)
+    lowerbounded, upperbounded = all(isfinite, b.lb), all(isfinite, b.ub)
+    return if lowerbounded
+        false
+    elseif upperbounded
+        # => decreasing
+        true
+    elseif all(!isfinite, b.lb) && all(!isfinite, b.ub)
+        # => all are unbounded so we have the identity
+        false
+    else
+        # => some are unbounded and some are bounded
+        false
+    end
+end
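
The case analysis above can be illustrated per entry with a scalar link from the constrained interval to the real line. This is a sketch under assumptions: `truncated_link` is a hypothetical function, and the three link forms are the usual textbook choices rather than a copy of the library's implementation. Only the upper-bounded-only case reverses order.

```julia
# Assumed scalar link from the constrained interval (lb, ub) to ℝ; these
# forms are illustrative and may differ from the library's implementation.
function truncated_link(x::Real, lb::Real, ub::Real)
    if isfinite(lb) && isfinite(ub)
        return log((x - lb) / (ub - x))   # logit of (x - lb) / (ub - lb)
    elseif isfinite(lb)
        return log(x - lb)                # lower bound only
    elseif isfinite(ub)
        return log(ub - x)                # upper bound only
    else
        return x                          # unbounded: identity
    end
end

xs = 0.1:0.1:0.9                          # sorted points inside (0, 1)
both  = map(x -> truncated_link(x, 0.0, 1.0), xs)
lower = map(x -> truncated_link(x, 0.0, Inf), xs)
upper = map(x -> truncated_link(x, -Inf, 1.0), xs)
```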
63 changes: 63 additions & 0 deletions src/interface.jl
@@ -231,6 +231,69 @@ transform!(::typeof(identity), x, y) = copy!(y, x)
 logabsdetjac(::typeof(identity), x) = zero(eltype(x))
 logabsdetjac!(::typeof(identity), x, logjac) = logjac

+###################
+# Other utilities #
+###################
+"""
+    is_monotonically_increasing(f)
+
+Returns `true` if `f` is monotonically increasing.
+"""
+is_monotonically_increasing(f) = false
+is_monotonically_increasing(::typeof(identity)) = true
+is_monotonically_increasing(binv::Inverse) = is_monotonically_increasing(inverse(binv))
+is_monotonically_increasing(ef::Elementwise) = is_monotonically_increasing(ef.x)
+function is_monotonically_increasing(cf::ComposedFunction)
+    # Here we have a few different cases:
+    #
+    # inner \ outer | inc | dec | other
+    # --------------+-----+-----+------
+    #      inc      | inc | dec |  NA
+    #      dec      | dec | inc |  NA
+    #     other     | NA  | NA  |  NA
+    # --------------+-----+-----+------
+    #
+    # where `inc` means monotonically increasing, `dec` means monotonically decreasing,
+    # and `NA` means not applicable, i.e. we should return `false`.
+    return if is_monotonically_increasing(cf.inner)
+        is_monotonically_increasing(cf.outer)
+    elseif is_monotonically_decreasing(cf.inner)
+        is_monotonically_decreasing(cf.outer)
+    else
+        false
+    end
+end
+
+"""
+    is_monotonically_decreasing(f)
+
+Returns `true` if `f` is monotonically decreasing.
+"""
+is_monotonically_decreasing(f) = false
+is_monotonically_decreasing(::typeof(identity)) = false
+is_monotonically_decreasing(binv::Inverse) = is_monotonically_decreasing(inverse(binv))
+is_monotonically_decreasing(ef::Elementwise) = is_monotonically_decreasing(ef.x)
+function is_monotonically_decreasing(cf::ComposedFunction)
+    # Here we have a few different cases:
+    #
+    # inner \ outer | inc | dec | other
+    # --------------+-----+-----+------
+    #      inc      | inc | dec |  NA
+    #      dec      | dec | inc |  NA
+    #     other     | NA  | NA  |  NA
+    # --------------+-----+-----+------
+    #
+    # where `inc` means monotonically increasing, `dec` means monotonically decreasing,
+    # and `NA` means not applicable, i.e. we should return `false`.
+    return if is_monotonically_increasing(cf.inner)
+        is_monotonically_decreasing(cf.outer)
+    elseif is_monotonically_decreasing(cf.inner)
+        is_monotonically_increasing(cf.outer)
+    else
+        false
+    end
+end
+
 ######################
 # Bijectors includes #
 ######################
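
The composition table in the comments reduces to one rule: matching directions compose to an increasing map, opposite directions to a decreasing one, and anything unknown stays unknown. A minimal standalone sketch of that rule follows; the `Monotonicity` enum and `compose` function are hypothetical stand-ins for the pair of predicates and are not part of Bijectors.

```julia
# Hypothetical stand-in for the pair of predicates: each function is tagged
# as increasing, decreasing, or neither.
@enum Monotonicity increasing decreasing neither

# Composition rule from the table: equal directions give an increasing
# map, opposite directions a decreasing one, anything else is unknown.
function compose(inner::Monotonicity, outer::Monotonicity)
    (inner == neither || outer == neither) && return neither
    return inner == outer ? increasing : decreasing
end

# Numerical spot check on a sorted grid: exp is increasing, negation is decreasing.
xs = collect(-2.0:0.5:2.0)
neg_exp = (-) ∘ exp            # decreasing ∘ increasing
neg_neg = (-) ∘ (-)            # decreasing ∘ decreasing
```

Mapping `neg_exp` over `xs` gives a reverse-sorted result and `neg_neg` a sorted one, matching `compose(increasing, decreasing)` and `compose(decreasing, decreasing)`.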
8 changes: 8 additions & 0 deletions test/Project.toml
@@ -1,4 +1,6 @@
 [deps]
+AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -11,14 +13,18 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
 LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
+AbstractMCMC = "5"
+AdvancedHMC = "0.6"
 ChainRulesTestUtils = "0.7, 1"
 ChangesOfVariables = "0.1"
 Combinatorics = "1.0.2"
@@ -30,7 +36,9 @@ ForwardDiff = "0.10.12"
 Functors = "0.1, 0.2, 0.3, 0.4"
 InverseFunctions = "0.1"
 LazyArrays = "1, 2"
+LogDensityProblems = "2"
 LogExpFunctions = "0.3.1"
+MCMCDiagnosticTools = "0.3"
 ReverseDiff = "1.4.2"
 Tracker = "0.2.11"
 Zygote = "0.6.63"
52 changes: 46 additions & 6 deletions test/ad/chainrules.jl
@@ -1,3 +1,25 @@
+using ChainRulesTestUtils: ChainRulesCore
+
+# HACK: This is a workaround to test `Bijectors._inv_link_chol_lkj` which produces an
+# upper-triangular `Matrix`, leading to `test_rrule` comparing the _full_ `Matrix`,
+# including the lower-triangular part which potentially contains `undef` entries.
+# Here we simply wrap the rrule we want to test to also convert to PD form, thus
+# avoiding any issues with the lower-triangular part.
+function _inv_link_chol_lkj_wrapper(y)
+    W, logJ = Bijectors._inv_link_chol_lkj(y)
+    return Bijectors.pd_from_upper(W), logJ
+end
+function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj_wrapper), y::AbstractVector)
+    (W, logJ), back = ChainRulesCore.rrule(Bijectors._inv_link_chol_lkj, y)
+    X, back_X = ChainRulesCore.rrule(Bijectors.pd_from_upper, W)
+    function pullback_inv_link_chol_lkj_wrapper((ΔX, ΔlogJ))
+        (_, ΔW) = back_X(ChainRulesCore.unthunk(ΔX))
+        (_, Δy) = back((ΔW, ΔlogJ))
+        return (ChainRulesCore.NoTangent(), Δy)
+    end
+    return (X, logJ), pullback_inv_link_chol_lkj_wrapper
+end
+
 @testset "chainrules" begin
     x = randn()
     y = expm1(randn())
@@ -22,11 +44,29 @@

     # LKJ and LKJCholesky bijector
     dist = LKJCholesky(3, 4)
-    x = rand(dist)
-    test_rrule(Bijectors._link_chol_lkj_from_upper, x.U)
-    test_rrule(Bijectors._link_chol_lkj_from_lower, x.L)
+    # Run multiple tests because we're working with `undef` entries, and so we
+    # want to make sure that we hit cases where the `undef` entries have different values.
+    # It's also just useful to test numerical stability for different realizations of `dist`.
+    for i in 1:30
+        x = rand(dist)
+        test_rrule(
+            Bijectors._link_chol_lkj_from_upper,
+            x.U;
+            testset_name="_link_chol_lkj_from_upper on $(typeof(x)) [$i]",
+        )
+        test_rrule(
+            Bijectors._link_chol_lkj_from_lower,
+            x.L;
+            testset_name="_link_chol_lkj_from_lower on $(typeof(x)) [$i]",
+        )
+
+        b = bijector(dist)
+        y = b(x)

-    b = bijector(dist)
-    y = b(x)
-    test_rrule(Bijectors._inv_link_chol_lkj, y)
+        test_rrule(
+            _inv_link_chol_lkj_wrapper,
+            y;
+            testset_name="_inv_link_chol_lkj on $(typeof(x)) [$i]",
+        )
+    end
 end

2 comments on commit 026a07a

@yebai (Member) commented on 026a07a Jun 27, 2024


@JuliaRegistrator

Error while trying to register: Version 0.13.14 already exists