Skip to content

Commit

Permalink
Fix type instabilities of SimplexBijector (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Apr 5, 2021
1 parent 2dee76c commit f8db035
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 54 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.8.16"
version = "0.9.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
20 changes: 10 additions & 10 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,31 +168,31 @@ isdirichlet(::Distribution) = false
function link(
d::Dirichlet,
x::AbstractVecOrMat{<:Real},
proj::Bool = true,
)
::Val{proj}=Val(true),
) where {proj}
return SimplexBijector{proj}()(x)
end

function link_jacobian(
d::Dirichlet,
x::AbstractVector{T},
proj::Bool = true,
) where {T<:Real}
x::AbstractVector{<:Real},
::Val{proj}=Val(true),
) where {proj}
return jacobian(SimplexBijector{proj}(), x)
end

function invlink(
d::Dirichlet,
y::AbstractVecOrMat{<:Real},
proj::Bool = true
)
::Val{proj}=Val(true),
) where {proj}
return inv(SimplexBijector{proj}())(y)
end
function invlink_jacobian(
d::Dirichlet,
y::AbstractVector{T},
proj::Bool = true
) where {T<:Real}
y::AbstractVector{<:Real},
::Val{proj}=Val(true),
) where {proj}
return jacobian(inv(SimplexBijector{proj}()), y)
end

Expand Down
24 changes: 11 additions & 13 deletions src/bijectors/simplex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
####################
struct SimplexBijector{N, T} <: Bijector{N} end
SimplexBijector() = SimplexBijector{1}()
function SimplexBijector{N}() where {N}
if N isa Bool
SimplexBijector{1, N}()
else
SimplexBijector{N, true}()
end
end
SimplexBijector{N}() where {N} = SimplexBijector{N,true}()

# Special case `N = 1`
SimplexBijector{true}() = SimplexBijector{1,true}()
SimplexBijector{false}() = SimplexBijector{1,false}()

(b::SimplexBijector{1})(x::AbstractVector) = _simplex_bijector(x, b)
(b::SimplexBijector{1})(y::AbstractVector, x::AbstractVector) = _simplex_bijector!(y, x, b)
Expand Down Expand Up @@ -276,8 +274,8 @@ end

function simplex_link_jacobian(
x::AbstractVector{T},
proj::Bool = true,
) where {T <: Real}
::Val{proj}=Val(true),
) where {T<:Real, proj}
K = length(x)
@assert K > 1 "x needs to be of length greater than 1"
dydxt = similar(x, length(x), length(x))
Expand Down Expand Up @@ -306,7 +304,7 @@ function simplex_link_jacobian(
return UpperTriangular(dydxt)'
end
function jacobian(b::SimplexBijector{1, proj}, x::AbstractVector{T}) where {proj, T}
return simplex_link_jacobian(x, proj)
return simplex_link_jacobian(x, Val(proj))
end

#=
Expand Down Expand Up @@ -377,8 +375,8 @@ end

function simplex_invlink_jacobian(
y::AbstractVector{T},
proj::Bool = true,
) where {T <: Real}
::Val{proj}=Val(true),
) where {T<:Real, proj}
K = length(y)
@assert K > 1 "x needs to be of length greater than 1"
dxdy = similar(y, length(y), length(y))
Expand Down Expand Up @@ -428,7 +426,7 @@ function simplex_invlink_jacobian(
end
# jacobian
function jacobian(ib::Inverse{<:SimplexBijector{1, proj}}, y::AbstractVector{T}) where {proj, T}
return simplex_invlink_jacobian(y, proj)
return simplex_invlink_jacobian(y, Val(proj))
end

#=
Expand Down
20 changes: 10 additions & 10 deletions src/compat/distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,31 @@ isdirichlet(::TuringDirichlet) = true
function link(
d::TuringDirichlet,
x::AbstractVecOrMat{<:Real},
proj::Bool = true,
)
::Val{proj}=Val(true),
) where {proj}
return SimplexBijector{proj}()(x)
end

function link_jacobian(
d::TuringDirichlet,
x::AbstractVector{T},
proj::Bool = true,
) where {T<:Real}
x::AbstractVector{<:Real},
::Val{proj}=Val(true),
) where {proj}
return jacobian(SimplexBijector{proj}(), x)
end

function invlink(
d::TuringDirichlet,
y::AbstractVecOrMat{<:Real},
proj::Bool = true
)
::Val{proj}=Val(true),
) where {proj}
return inv(SimplexBijector{proj}())(y)
end
function invlink_jacobian(
d::TuringDirichlet,
y::AbstractVector{T},
proj::Bool = true
) where {T<:Real}
y::AbstractVector{<:Real},
::Val{proj}=Val(true),
) where {proj}
return jacobian(inv(SimplexBijector{proj}()), y)
end

Expand Down
40 changes: 20 additions & 20 deletions test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,20 @@ function single_sample_tests(dist)

# Check that invlink is inverse of link.
x = rand(dist)
@test invlink(dist, link(dist, copy(x))) x atol=1e-9
@test @inferred(invlink(dist, link(dist, copy(x)))) x atol=1e-9

# Check that link is inverse of invlink. Hopefully this just holds given the above...
y = link(dist, x)
y = @inferred(link(dist, x))
if dist isa Dirichlet
# `logit` and `logistic` are not perfect inverses. This leads to a diversion.
# Example:
# julia> logit(logistic(0.9999999999999998))
# 1.0
# julia> logistic(logit(0.9999999999999998))
# 0.9999999999999998
@test link(dist, invlink(dist, copy(y))) y atol=0.5
@test @inferred(link(dist, invlink(dist, copy(y)))) y atol=0.5
else
@test link(dist, invlink(dist, copy(y))) y atol=1e-9
@test @inferred(link(dist, invlink(dist, copy(y)))) y atol=1e-9
end
if dist isa SimplexDistribution
# This should probably be exact.
Expand All @@ -72,9 +72,9 @@ end
# univariate distributions, just a vector of identical values. For vector-valued
# distributions, a matrix whose columns are identical.
function multi_sample_tests(dist, x, xs, N)
ys = link(dist, copy(xs))
@test invlink(dist, link(dist, copy(xs))) xs atol=1e-9
@test link(dist, invlink(dist, copy(ys))) ys atol=1e-9
ys = @inferred(link(dist, copy(xs)))
@test @inferred(invlink(dist, link(dist, copy(xs)))) xs atol=1e-9
@test @inferred(link(dist, invlink(dist, copy(ys)))) ys atol=1e-9
@test logpdf_with_trans(dist, xs, true) == fill(logpdf_with_trans(dist, x, true), N)
@test logpdf_with_trans(dist, xs, false) == fill(logpdf_with_trans(dist, x, false), N)

Expand Down Expand Up @@ -147,7 +147,7 @@ let ϵ = eps(Float64)
# This should fail at the minute. Not sure what the correct way to test this is.
x = rand(dist)
logpdf_turing = logpdf_with_trans(dist, x, true)
J = jacobian(x->link(dist, x, false), x)
J = jacobian(x->link(dist, x, Val(false)), x)
@test logpdf(dist, x .+ ϵ) - _logabsdet(J) logpdf_turing

# Issue #12
Expand Down Expand Up @@ -264,18 +264,18 @@ end
function test_link_and_invlink()
dist = Dirichlet(4, 4)
x = rand(dist)
y = link(dist, x)

f1 = x -> link(dist, x, true)
f2 = x -> link(dist, x, false)
g1 = y -> invlink(dist, y, true)
g2 = y -> invlink(dist, y, false)

@test @aeq jacobian(f1, x) Bijectors.simplex_link_jacobian(x, true)
@test @aeq jacobian(f2, x) Bijectors.simplex_link_jacobian(x, false)
@test @aeq jacobian(g1, y) Bijectors.simplex_invlink_jacobian(y, true)
@test @aeq jacobian(g2, y) Bijectors.simplex_invlink_jacobian(y, false)
@test @aeq Bijectors.simplex_link_jacobian(x, false) * Bijectors.simplex_invlink_jacobian(y, false) I
y = @inferred(link(dist, x))

f1 = x -> link(dist, x, Val(true))
f2 = x -> link(dist, x, Val(false))
g1 = y -> invlink(dist, y, Val(true))
g2 = y -> invlink(dist, y, Val(false))

@test @aeq jacobian(f1, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(true)))
@test @aeq jacobian(f2, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(false)))
@test @aeq jacobian(g1, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(true)))
@test @aeq jacobian(g2, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(false)))
@test @aeq Bijectors.simplex_link_jacobian(x, Val(false)) * Bijectors.simplex_invlink_jacobian(y, Val(false)) I
end
for i in 1:4
test_link_and_invlink()
Expand Down

2 comments on commit f8db035

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33586

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.0 -m "<description of version>" f8db035e765c923476f4c742f75dbac753ff71ed
git push origin v0.9.0

Please sign in to comment.