Skip to content

Commit

Permalink
Fix for product of Dirichlet (#322)
Browse files Browse the repository at this point in the history
* Added default impl of `_logabdetjac_dist` so we can support non-batch
by default

* Added test for product of `Dirichlet`

* Bump patch version

* Update src/Bijectors.jl

Co-authored-by: David Widmann <[email protected]>

* Work around eachslice limitation on Julia <v1.9 (#323)

* Work around eachslice limitation on Julia <v1.9

* Bug fix

* Fix Tapir tests

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/ad/chainrules.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
5 people authored Aug 14, 2024
1 parent 7c8b533 commit 891c832
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.13.17"
version = "0.13.18"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
5 changes: 5 additions & 0 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ invlink(d::Distribution, y) = inverse(bijector(d))(y)

# To still allow `logpdf_with_trans` to work with "batches" in a similar way
# as `logpdf` can.

# Default catch-all so we can work with distributions by default and batch-support can be
# added when needed.
_logabsdetjac_dist(d::Distribution, x) = logabsdetjac(bijector(d), x)

_logabsdetjac_dist(d::UnivariateDistribution, x::Real) = logabsdetjac(bijector(d), x)
function _logabsdetjac_dist(d::UnivariateDistribution, x::AbstractArray)
return logabsdetjac.((bijector(d),), x)
Expand Down
22 changes: 16 additions & 6 deletions src/bijectors/product_bijector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ inverse(b::ProductBijector) = ProductBijector(map(inverse, b.bs))

function _product_bijector_check_dim(::Val{N}, ::Val{M}) where {N,M}
if N > M
throw(
DimensionMismatch(
"Number of bijectors needs to be smaller than or equal to the number of dimensions",
),
)
msg = """
Number of bijectors needs to be smaller than or equal to the number of dimensions
"""
throw(DimensionMismatch(msg))
end
end

Expand All @@ -23,7 +22,18 @@ function _product_bijector_slices(

# If N < M, then the bijectors expect an input vector of dimension `M - N`.
# To achieve this, we need to slice along the last `N` dimensions.
return eachslice(x; dims=ntuple(i -> i + (M - N), N))
slice_indices = ntuple(i -> i + (M - N), N)
if VERSION >= v"1.9"
return eachslice(x; dims=slice_indices)
else
# Earlier Julia versions can't eachslice over multiple dimensions, so reshape the
# slice dimensions into a single one.
other_dims = tuple((size(x, i) for i in 1:(M - N))...)
slice_dims = tuple((size(x, i) for i in (1 + M - N):M)...)
x_reshaped = reshape(x, other_dims..., prod(slice_dims))
slices = eachslice(x_reshaped; dims=M - N + 1)
return reshape(collect(slices), slice_dims)
end
end

# Specialization for case where we're just applying elementwise.
Expand Down
33 changes: 27 additions & 6 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,35 @@ end

if @isdefined Tapir
rng = Xoshiro(123456)
Tapir.TestUtils.test_rrule!!(
rng, Bijectors.find_alpha, x, y, z; is_primitive=true, perf_flag=:none
Tapir.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
z;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
)
Tapir.TestUtils.test_rrule!!(
rng, Bijectors.find_alpha, x, y, 3; is_primitive=true, perf_flag=:none
Tapir.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
3;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
)
Tapir.TestUtils.test_rrule!!(
rng, Bijectors.find_alpha, x, y, UInt32(3); is_primitive=true, perf_flag=:none
Tapir.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
UInt32(3);
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
)
end

Expand Down
9 changes: 9 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,15 @@ end
end
end

@testset "ProductDistribution" begin
d = product_distribution(fill(Dirichlet(ones(4)), 2, 3))
x = rand(d)
b = bijector(d)

@test logpdf_with_trans(d, x, false) == logpdf(d, x)
@test logpdf_with_trans(d, x, true) == logpdf(d, x) - logabsdetjac(b, x)
end

@testset "DistributionsAD" begin
@testset "$dist" for dist in [
filldist(Normal(), 2),
Expand Down

4 comments on commit 891c832

@mhauru
Copy link
Member

@mhauru mhauru commented on 891c832 Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Register Failed
@mhauru, it looks like you are not a publicly listed member/owner in the parent organization (TuringLang).
If you are a member/owner, you will need to change your membership to public. See GitHub Help

@mhauru
Copy link
Member

@mhauru mhauru commented on 891c832 Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/113121

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.18 -m "<description of version>" 891c8324b12c7455635dbc7308a7b653961967ce
git push origin v0.13.18

Also, note the warning: Version 0.13.18 skips over 0.13.17
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

Please sign in to comment.