Skip to content

Commit

Permalink
Missing impl of with_logabsdet_jacobian for PDBijector (#245)
Browse files Browse the repository at this point in the history
* added missing implementation for with_logabsdet_jacobian and tests for
PDBijector

* added more informative error message in the case where
with_logabsdet_jacobian has not been implemented and transform and
logabsdetjac fail

* bump patch version

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* reverted a change

* fixed default impls of transform and logabsdetjac

* fixed logabsdetjac_pdbijector_chol

* qualified reference to logtwo

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
torfjelde and devmotion authored Feb 5, 2023
1 parent 622865b commit 93a0b16
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.12.0"
version = "0.12.1"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
2 changes: 1 addition & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ using LinearAlgebra: AbstractTriangular

using InverseFunctions: InverseFunctions

import ChangesOfVariables: with_logabsdet_jacobian
import ChangesOfVariables: ChangesOfVariables, with_logabsdet_jacobian
import InverseFunctions: inverse

import ChainRulesCore
Expand Down
19 changes: 13 additions & 6 deletions src/bijectors/pd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,19 @@ function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real})
if !issuccess(Xcf)
Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I)
end
return logabsdetjac(b, Xcf)
return logabsdetjac_pdbijector_chol(Xcf)
end

function logabsdetjac(b::PDBijector, Xcf::Cholesky)
U = Xcf.U
T = eltype(U)
d = size(U, 1)
return - sum((d .- (1:d) .+ 2) .* log.(diag(U))) - d * log(T(2))
function logabsdetjac_pdbijector_chol(Xcf::Cholesky)
# NOTE: Use `UpperTriangular` here because we only need `diag(U)`
# and `UL` is by default already constructed in `Cholesky`.
UL = Xcf.UL
d = size(UL, 1)
z = sum(((d + 1):(-1):2) .* log.(diag(UL)))
return - (z + d * oftype(z, IrrationalConstants.logtwo))
end

# TODO: Implement explicitly.
function with_logabsdet_jacobian(b::PDBijector, X)
return transform(b, X), logabsdetjac(b, X)
end
18 changes: 16 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,14 @@ Broadcast.broadcastable(b::Transform) = Ref(b)
Transform `x` using `b`, treating `x` as a single input.
"""
transform(f::F, x) where {F<:Function} = f(x)
transform(t::Transform, x) = first(with_logabsdet_jacobian(t, x))
function transform(t::Transform, x)
res = with_logabsdet_jacobian(t, x)
if res isa ChangesOfVariables.NoLogAbsDetJacobian
error("`transform` not implemented for $(typeof(b)); implement `transform` and/or `with_logabsdet_jacobian`.")
end

return first(res)
end

"""
transform!(b, x[, y])
Expand All @@ -73,7 +80,14 @@ transform!(b, x, y) = copyto!(y, transform(b, x))
Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`.
"""
logabsdetjac(b, x) = last(with_logabsdet_jacobian(b, x))
function logabsdetjac(b, x)
res = with_logabsdet_jacobian(b, x)
if res isa ChangesOfVariables.NoLogAbsDetJacobian
error("`logabsdetjac` not implemented for $(typeof(b)); implement `logabsdetjac` and/or `with_logabsdet_jacobian`.")
end

return last(res)
end

"""
logabsdetjac!(b, x[, logjac])
Expand Down
13 changes: 13 additions & 0 deletions test/bijectors/pd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Bijectors, DistributionsAD, LinearAlgebra, Test
using Bijectors: PDBijector

@testset "PDBijector" begin
d = 5
b = PDBijector()
dist = Wishart(d, Matrix{Float64}(I, d, d))
x = rand(dist)
# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian`
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0.
# Hence, we disable those tests.
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false)
end
20 changes: 16 additions & 4 deletions test/bijectors/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ function test_bijector(
logjac=nothing,
test_not_identity=isnothing(y) && isnothing(logjac),
test_types=false,
changes_of_variables_test=true,
inverse_functions_test=true,
compare=isapprox,
kwargs...
)
Expand All @@ -29,12 +31,22 @@ function test_bijector(
end

# ChangesOfVariables.jl
ChangesOfVariables.test_with_logabsdet_jacobian(b, x, getjacobian; compare=compare, kwargs...)
ChangesOfVariables.test_with_logabsdet_jacobian(ib, isnothing(y) ? y_test : y, getjacobian; compare=compare, kwargs...)
# For non-bijective transformations, these tests always fail since determinant of
# the Jacobian is zero. Hence we allow the caller to disable them if necessary.
if changes_of_variables_test
ChangesOfVariables.test_with_logabsdet_jacobian(b, x, getjacobian; compare=compare, kwargs...)
ChangesOfVariables.test_with_logabsdet_jacobian(
ib, isnothing(y) ? y_test : y, getjacobian;
compare=compare,
kwargs...
)
end

# InverseFunctions.jl
InverseFunctions.test_inverse(b, x; compare, kwargs...)
InverseFunctions.test_inverse(ib, isnothing(y) ? y_test : y; compare=compare, kwargs...)
if inverse_functions_test
InverseFunctions.test_inverse(b, x; compare, kwargs...)
InverseFunctions.test_inverse(ib, isnothing(y) ? y_test : y; compare=compare, kwargs...)
end

# Always want the following to hold
@test compare(ires[1], x; kwargs...)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ if GROUP == "All" || GROUP == "Interface"
include("bijectors/leaky_relu.jl")
include("bijectors/coupling.jl")
include("bijectors/ordered.jl")
include("bijectors/pd.jl")
end

if GROUP == "All" || GROUP == "AD"
Expand Down

2 comments on commit 93a0b16

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/77056

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.1 -m "<description of version>" 93a0b16c7986d5a483e5221aa47f1314b75b0151
git push origin v0.12.1

Please sign in to comment.