Skip to content

Commit 8a4790a

Browse files
torfjelde, github-actions[bot], devmotion
authored
Vector-version for PDBijector (#271)
* initial work on PDVecBijector * added output_length and output_size to compute output, well, lengths and sizes for transformations * added tests for size of transformed dist using VecCorrBijector * use already constructed transformation * TransformedDistribution should now also have correct variate form * added proper variateform handling for VecCholeskyBijector too * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added output_size impl for Reshape too * added output_size for PDVecBijector and tests * made bijector for PD distributions use PDVecBijector * bump minor version * Update src/bijectors/pd.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * move utilities from bijectors/corr.jl to utils.jl * fixed Tracker for PD matrices * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix for matrix AD tests * bumped patch version * revert patch version * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Update src/utils.jl Co-authored-by: David Widmann <[email protected]> * removed unnecessary hacks for importing chainrules rule into ReverseDiff * mark triu_mask as non-differentiable * shifted some methods around to help with readability * removed redundant wrap_chainrules_output in BijectorsReverseDiffExt * renamed confusing name in pd tests --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]>
1 parent 03bdffb commit 8a4790a

8 files changed

+202
-126
lines changed

ext/BijectorsReverseDiffExt.jl

+2-20
Original file line numberDiff line numberDiff line change
@@ -250,32 +250,14 @@ end
250250
end
251251

252252
# `OrderedBijector`
253-
function _transform_ordered(y::Union{TrackedVector,TrackedMatrix})
254-
return track(_transform_ordered, y)
255-
end
256-
@grad function _transform_ordered(y::AbstractVecOrMat)
257-
x, dx = ChainRulesCore.rrule(_transform_ordered, value(y))
258-
return x, (wrap_chainrules_output ∘ Base.tail ∘ dx)
259-
end
260-
261-
function _transform_inverse_ordered(x::Union{TrackedVector,TrackedMatrix})
262-
return track(_transform_inverse_ordered, x)
263-
end
264-
@grad function _transform_inverse_ordered(x::AbstractVecOrMat)
265-
y, dy = ChainRulesCore.rrule(_transform_inverse_ordered, value(x))
266-
return y, (wrap_chainrules_output ∘ Base.tail ∘ dy)
267-
end
253+
@grad_from_chainrules _transform_ordered(y::Union{TrackedVector,TrackedMatrix})
254+
@grad_from_chainrules _transform_inverse_ordered(x::Union{TrackedVector,TrackedMatrix})
268255

269256
@grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int)
270257

271258
@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix)
272259
@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector)
273260

274-
# NOTE: Probably doesn't work in complete generality.
275-
wrap_chainrules_output(x) = x
276-
wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing
277-
wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
278-
279261
if VERSION <= v"1.8.0-DEV.1526"
280262
# HACK: This dispatch does not wrap X in Hermitian before calling cholesky.
281263
# cholesky does not work with AbstractMatrix in julia versions before the compared one,

ext/BijectorsTrackerExt.jl

+17
Original file line numberDiff line numberDiff line change
@@ -532,4 +532,21 @@ wrap_chainrules_output(x) = x
532532
wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing
533533
wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
534534

535+
# `update_triu_from_vec`
536+
function Bijectors.update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int)
537+
return track(Bijectors.update_triu_from_vec, vals, k, dim)
538+
end
539+
540+
@grad function Bijectors.update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int)
541+
# HACK: This doesn't support higher order!
542+
y, dy = ChainRulesCore.rrule(Bijectors.update_triu_from_vec, data(vals), k, dim)
543+
return y, (wrap_chainrules_output ∘ Base.tail ∘ dy)
544+
end
545+
546+
Bijectors.upper_triangular(A::TrackedMatrix) = track(Bijectors.upper_triangular, A)
547+
@grad function Bijectors.upper_triangular(A::AbstractMatrix)
548+
Ad = data(A)
549+
return Bijectors.upper_triangular(Ad), Δ -> (Bijectors.upper_triangular(Δ),)
550+
end
551+
535552
end

src/bijectors/corr.jl

-91
Original file line numberDiff line numberDiff line change
@@ -89,97 +89,6 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
8989
return -logabsdetjac(inverse(b), (b(X)))
9090
end
9191

92-
"""
93-
triu_mask(X::AbstractMatrix, k::Int)
94-
95-
Return a mask for elements of `X` above the `k`th diagonal.
96-
"""
97-
function triu_mask(X::AbstractMatrix, k::Int)
98-
# Ensure that we're working with a square matrix.
99-
LinearAlgebra.checksquare(X)
100-
101-
# Using `similar` allows us to respect device of array, etc., e.g. `CuArray`.
102-
m = similar(X, Bool)
103-
return triu(.~m .| m, k)
104-
end
105-
106-
triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)]
107-
108-
function update_triu_from_vec!(
109-
vals::AbstractVector{<:Real}, k::Int, X::AbstractMatrix{<:Real}
110-
)
111-
# Ensure that we're working with one-based indexing.
112-
# `triu` requires this too.
113-
LinearAlgebra.require_one_based_indexing(X)
114-
115-
# Set the values.
116-
idx = 1
117-
m, n = size(X)
118-
for j in 1:n
119-
for i in 1:min(j - k, m)
120-
X[i, j] = vals[idx]
121-
idx += 1
122-
end
123-
end
124-
125-
return X
126-
end
127-
128-
function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int)
129-
X = similar(vals, dim, dim)
130-
# TODO: Do we need this?
131-
X .= 0
132-
return update_triu_from_vec!(vals, k, X)
133-
end
134-
135-
function ChainRulesCore.rrule(
136-
::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int
137-
)
138-
function update_triu_from_vec_pullback(ΔX)
139-
return (
140-
ChainRulesCore.NoTangent(),
141-
triu_to_vec(ChainRulesCore.unthunk(ΔX), k),
142-
ChainRulesCore.NoTangent(),
143-
ChainRulesCore.NoTangent(),
144-
)
145-
end
146-
return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback
147-
end
148-
149-
# n * (n - 1) / 2 = d
150-
# ⟺ n^2 - n - 2d = 0
151-
# ⟹ n = (1 + sqrt(1 + 8d)) / 2
152-
_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2
153-
154-
"""
155-
triu1_to_vec(X::AbstractMatrix{<:Real})
156-
157-
Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector.
158-
"""
159-
triu1_to_vec(X::AbstractMatrix) = triu_to_vec(X, 1)
160-
161-
inverse(::typeof(triu1_to_vec)) = vec_to_triu1
162-
163-
"""
164-
vec_to_triu1(x::AbstractVector{<:Real})
165-
166-
Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`.
167-
"""
168-
function vec_to_triu1(x::AbstractVector)
169-
n = _triu1_dim_from_length(length(x))
170-
X = update_triu_from_vec(x, 1, n)
171-
return upper_triangular(X)
172-
end
173-
174-
inverse(::typeof(vec_to_triu1)) = triu1_to_vec
175-
176-
function vec_to_triu1_row_index(idx)
177-
# Assumes that vector was saved in a column-major order
178-
# and that vector is one-based indexed.
179-
M = _triu1_dim_from_length(idx - 1)
180-
return idx - (M * (M - 1) ÷ 2)
181-
end
182-
18392
"""
18493
VecCorrBijector <: Bijector
18594

src/bijectors/pd.jl

+27
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,30 @@ end
4040
function with_logabsdet_jacobian(b::PDBijector, X)
4141
return transform(b, X), logabsdetjac(b, X)
4242
end
43+
44+
struct PDVecBijector <: Bijector end
45+
46+
transform(::PDVecBijector, X::AbstractMatrix{<:Real}) = pd_vec_link(X)
47+
pd_vec_link(X) = triu_to_vec(transpose(pd_link(X)))
48+
49+
function transform(::Inverse{PDVecBijector}, y::AbstractVector{<:Real})
50+
Y = permutedims(vec_to_triu(y))
51+
return transform(inverse(PDBijector()), Y)
52+
end
53+
54+
logabsdetjac(::PDVecBijector, X::AbstractMatrix{<:Real}) = logabsdetjac(PDBijector(), X)
55+
56+
function with_logabsdet_jacobian(b::PDVecBijector, X)
57+
return transform(b, X), logabsdetjac(b, X)
58+
end
59+
60+
function output_size(::PDVecBijector, sz::Tuple{Int,Int})
61+
n = first(sz)
62+
d = (n^2 + n) ÷ 2
63+
return (d,)
64+
end
65+
66+
function output_size(::Inverse{PDVecBijector}, sz::Tuple{Int})
67+
n = _triu_dim_from_length(first(sz))
68+
return (n, n)
69+
end

src/transformed_distribution.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ bijector(d::BoundedDistribution) = bijector_bounded(d)
8282
const LowerboundedDistribution = Union{Pareto,Levy}
8383
bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d)
8484

85-
bijector(d::PDMatDistribution) = PDBijector()
86-
bijector(d::MatrixBeta) = PDBijector()
85+
bijector(d::PDMatDistribution) = PDVecBijector()
86+
bijector(d::MatrixBeta) = PDVecBijector()
8787

8888
bijector(d::LKJ) = VecCorrBijector()
8989
bijector(d::LKJCholesky) = VecCholeskyBijector(d.uplo)

src/utils.jl

+120
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,123 @@ cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(Hermitian(X)))
1818
cholesky_factor(X::Cholesky) = X.U
1919
cholesky_factor(X::UpperTriangular) = X
2020
cholesky_factor(X::LowerTriangular) = X
21+
22+
"""
23+
triu_mask(X::AbstractMatrix, k::Int)
24+
25+
Return a mask for elements of `X` above the `k`th diagonal.
26+
"""
27+
function triu_mask(X::AbstractMatrix, k::Int)
28+
# Ensure that we're working with a square matrix.
29+
LinearAlgebra.checksquare(X)
30+
31+
# Using `similar` allows us to respect device of array, etc., e.g. `CuArray`.
32+
m = similar(X, Bool)
33+
return triu!(fill!(m, true), k)
34+
end
35+
36+
ChainRulesCore.@non_differentiable triu_mask(X::AbstractMatrix, k::Int)
37+
38+
_triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)]
39+
40+
function update_triu_from_vec!(
41+
vals::AbstractVector{<:Real}, k::Int, X::AbstractMatrix{<:Real}
42+
)
43+
# Ensure that we're working with one-based indexing.
44+
# `triu` requires this too.
45+
LinearAlgebra.require_one_based_indexing(X)
46+
47+
# Set the values.
48+
idx = 1
49+
m, n = size(X)
50+
for j in 1:n
51+
for i in 1:min(j - k, m)
52+
X[i, j] = vals[idx]
53+
idx += 1
54+
end
55+
end
56+
57+
return X
58+
end
59+
60+
function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int)
61+
X = similar(vals, dim, dim)
62+
# TODO: Do we need this?
63+
fill!(X, 0)
64+
return update_triu_from_vec!(vals, k, X)
65+
end
66+
67+
function ChainRulesCore.rrule(
68+
::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int
69+
)
70+
function update_triu_from_vec_pullback(ΔX)
71+
return (
72+
ChainRulesCore.NoTangent(),
73+
_triu_to_vec(ChainRulesCore.unthunk(ΔX), k),
74+
ChainRulesCore.NoTangent(),
75+
ChainRulesCore.NoTangent(),
76+
)
77+
end
78+
return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback
79+
end
80+
81+
# n * (n - 1) / 2 = d
82+
# ⟺ n^2 - n - 2d = 0
83+
# ⟹ n = (1 + sqrt(1 + 8d)) / 2
84+
_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2
85+
86+
"""
87+
triu1_to_vec(X::AbstractMatrix{<:Real})
88+
89+
Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector.
90+
"""
91+
triu1_to_vec(X::AbstractMatrix) = _triu_to_vec(X, 1)
92+
93+
inverse(::typeof(triu1_to_vec)) = vec_to_triu1
94+
95+
"""
96+
vec_to_triu1(x::AbstractVector{<:Real})
97+
98+
Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`.
99+
"""
100+
function vec_to_triu1(x::AbstractVector)
101+
n = _triu1_dim_from_length(length(x))
102+
X = update_triu_from_vec(x, 1, n)
103+
return upper_triangular(X)
104+
end
105+
106+
inverse(::typeof(vec_to_triu1)) = triu1_to_vec
107+
108+
function vec_to_triu1_row_index(idx)
109+
# Assumes that vector was saved in a column-major order
110+
# and that vector is one-based indexed.
111+
M = _triu1_dim_from_length(idx - 1)
112+
return idx - (M * (M - 1) ÷ 2)
113+
end
114+
115+
# Triangular matrix with diagonals.
116+
117+
# (n^2 + n) / 2 = d
118+
# ⟺ n² + n - 2d = 0
119+
# ⟺ n = (-1 + sqrt(1 + 8d)) / 2
120+
_triu_dim_from_length(d) = (-1 + isqrt(1 + 8 * d)) ÷ 2
121+
122+
"""
123+
triu_to_vec(X::AbstractMatrix{<:Real})
124+
125+
Extracts elements from upper triangle of `X` and returns them as a vector.
126+
"""
127+
triu_to_vec(X::AbstractMatrix) = _triu_to_vec(X, 0)
128+
129+
"""
130+
vec_to_triu(x::AbstractVector{<:Real})
131+
132+
Constructs a matrix from a vector `x` by filling the upper triangle.
133+
"""
134+
function vec_to_triu(x::AbstractVector)
135+
n = _triu_dim_from_length(length(x))
136+
X = update_triu_from_vec(x, 0, n)
137+
return upper_triangular(X)
138+
end
139+
140+
inverse(::typeof(vec_to_triu)) = triu_to_vec

test/bijectors/pd.jl

+33-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,37 @@
11
using Bijectors, DistributionsAD, LinearAlgebra, Test
2-
using Bijectors: PDBijector
2+
using Bijectors: PDBijector, PDVecBijector
33

44
@testset "PDBijector" begin
5-
d = 5
6-
b = PDBijector()
7-
dist = Wishart(d, Matrix{Float64}(I, d, d))
8-
x = rand(dist)
9-
# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian`
10-
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0.
11-
# Hence, we disable those tests.
12-
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false)
5+
for d in [2, 5]
6+
b = PDBijector()
7+
dist = Wishart(d, Matrix{Float64}(I, d, d))
8+
x = rand(dist)
9+
# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian`
10+
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0.
11+
# Hence, we disable those tests.
12+
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false)
13+
end
14+
end
15+
16+
@testset "PDVecBijector" begin
17+
for d in [2, 5]
18+
b = PDVecBijector()
19+
dist = Wishart(d, Matrix{Float64}(I, d, d))
20+
x = rand(dist)
21+
y = b(x)
22+
23+
# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian`
24+
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0.
25+
# Hence, we disable those tests.
26+
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false)
27+
28+
# Check that output sizes are computed correctly.
29+
tdist = transformed(dist, b)
30+
@test length(tdist) == length(y)
31+
@test tdist isa MultivariateDistribution
32+
33+
dist_transformed = transformed(MvNormal(zeros(length(tdist)), I), inverse(b))
34+
@test size(dist_transformed) == size(x)
35+
@test dist_transformed isa MatrixDistribution
36+
end
1337
end

test/transform.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,12 @@ end
183183

184184
x = rand(dist)
185185
x = x + x' + 2I
186-
lowerinds = [
187-
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[1] >= I[2]
188-
]
189186
upperinds = [
190187
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] >= I[1]
191188
]
192189
logpdf_turing = logpdf_with_trans(dist, x, true)
193190
J = ForwardDiff.jacobian(x -> link(dist, x), x)
194-
J = J[lowerinds, upperinds]
191+
J = J[:, upperinds]
195192
@test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing
196193
end
197194
end

0 commit comments

Comments
 (0)