Skip to content

Commit 3e0d765

Browse files
authored
Define adjoints for inverse of PlanarLayer (#160)
1 parent 827b80a commit 3e0d765

13 files changed

+100
-30
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.8.12"
3+
version = "0.8.13"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -19,6 +20,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1920

2021
[compat]
2122
ArgCheck = "1, 2"
23+
ChainRulesCore = "0.9"
2224
Compat = "3"
2325
Distributions = "0.23.3, 0.24"
2426
MappedArrays = "0.2.2, 0.3"

src/Bijectors.jl

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ using MappedArrays
3636
using Base.Iterators: drop
3737
using LinearAlgebra: AbstractTriangular
3838
import NonlinearSolve
39+
import ChainRulesCore
3940

4041
export TransformDistribution,
4142
PositiveDistribution,
@@ -243,6 +244,7 @@ end
243244

244245
include("utils.jl")
245246
include("interface.jl")
247+
include("chainrules.jl")
246248

247249
# Broadcasting here breaks Tracker for some reason
248250
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)

src/bijectors/planar_layer.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,16 @@ D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
151151
arXiv:1505.05770
152152
"""
153153
function find_alpha(wt_y::Real, wt_u_hat::Real, b::Real)
154-
# Compute the initial bracket
155-
_wt_y, _wt_u_hat, _b = promote(wt_y, wt_u_hat, b)
156-
initial_bracket = (_wt_y - abs(_wt_u_hat), _wt_y + abs(_wt_u_hat))
154+
# avoid promotions in root-finding algorithm and simplify AD dispatches
155+
return find_alpha(promote(wt_y, wt_u_hat, b)...)
156+
end
157+
function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real}
158+
# Compute the initial bracket (see above).
159+
initial_bracket = (wt_y - abs(wt_u_hat), wt_y + abs(wt_u_hat))
157160

158161
# Try to solve the root-finding problem, i.e., compute a final bracket
159162
prob = NonlinearSolve.NonlinearProblem{false}(initial_bracket) do α, _
160-
α + _wt_u_hat * tanh+ _b) - _wt_y
163+
α + wt_u_hat * tanh+ b) - wt_y
161164
end
162165
sol = NonlinearSolve.solve(prob, NonlinearSolve.Falsi())
163166
if sol.retcode === NonlinearSolve.MAXITERS_EXCEED

src/chainrules.jl

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# differentation rule for the iterative algorithm in the inverse of `PlanarLayer`
2+
ChainRulesCore.@scalar_rule(
3+
find_alpha(wt_y::Real, wt_u_hat::Real, b::Real),
4+
@setup(
5+
x = inv(1 + wt_u_hat * sech+ b)^2),
6+
),
7+
(x, - tanh+ b) * x, x - 1),
8+
)

src/compat/reversediff.jl

+13
Original file line numberDiff line numberDiff line change
@@ -181,5 +181,18 @@ lower(A::TrackedMatrix) = track(lower, A)
181181
return lower(Ad), Δ -> (lower(Δ),)
182182
end
183183

184+
function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal}
185+
return track(find_alpha, wt_y, wt_u_hat, b)
186+
end
187+
@grad function find_alpha(wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal)
188+
α = find_alpha(data(wt_y), data(wt_u_hat), data(b))
189+
190+
∂wt_y = inv(1 + wt_u_hat * sech+ b)^2)
191+
∂wt_u_hat = - tanh+ b) * ∂wt_y
192+
∂b = ∂wt_y - 1
193+
find_alpha_pullback::Real) =* ∂wt_y, Δ * ∂wt_u_hat, Δ * ∂b)
194+
195+
return α, find_alpha_pullback
196+
end
184197

185198
end

src/compat/tracker.jl

+14
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,17 @@ _link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
446446

447447
return z, pullback_link_chol_lkj
448448
end
449+
450+
function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal}
451+
return track(find_alpha, wt_y, wt_u_hat, b)
452+
end
453+
@grad function find_alpha(wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal)
454+
α = find_alpha(data(wt_y), data(wt_u_hat), data(b))
455+
456+
∂wt_y = inv(1 + wt_u_hat * sech+ b)^2)
457+
∂wt_u_hat = - tanh+ b) * ∂wt_y
458+
∂b = ∂wt_y - 1
459+
find_alpha_pullback::Real) =* ∂wt_y, Δ * ∂wt_u_hat, Δ * ∂b)
460+
461+
return α, find_alpha_pullback
462+
end

src/interface.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ has a closed-form implementation.
5151
5252
Most bijectors have closed-form evaluations, but there are cases where
5353
this is not the case. For example the *inverse* evaluation of `PlanarLayer`
54-
requires an iterative procedure to evaluate and thus is not differentiable.
54+
requires an iterative procedure to evaluate.
5555
"""
5656
isclosedform(b::Bijector) = true
5757

test/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
23
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
34
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
45
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
@@ -12,6 +13,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1213
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1314

1415
[compat]
16+
ChainRulesTestUtils = "0.5"
1517
Combinatorics = "1.0.2"
1618
DistributionsAD = "0.6.3"
1719
FiniteDifferences = "0.11, 0.12"

test/ad/chainrules.jl

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
@testset "chainrules" begin
2+
x, Δx, x̄ = randn(3)
3+
y, Δy, ȳ = randn(3)
4+
z, Δz, z̄ = randn(3)
5+
Δu = randn()
6+
7+
= expm1(y)
8+
frule_test(Bijectors.find_alpha, (x, Δx), (ỹ, Δy), (z, Δz); rtol=1e-3, atol=1e-3)
9+
rrule_test(Bijectors.find_alpha, Δu, (x, x̄), (ỹ, ȳ), (z, z̄); rtol=1e-3, atol=1e-3)
10+
end

test/ad/flows.jl

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
@testset "PlanarLayer" begin
2+
# logpdf of a flow with a planar layer and two-dimensional inputs
3+
test_ad(randn(7)) do θ
4+
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
5+
flow = transformed(MvNormal(2, 1), layer)
6+
return logpdf_forward(flow, θ[6:7])
7+
end
8+
test_ad(randn(11)) do θ
9+
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
10+
flow = transformed(MvNormal(2, 1), layer)
11+
return sum(logpdf_forward(flow, reshape(θ[6:end], 2, :)))
12+
end
13+
14+
# logpdf of a flow with the inverse of a planar layer and two-dimensional inputs
15+
test_ad(randn(7)) do θ
16+
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
17+
flow = transformed(MvNormal(2, 1), inv(layer))
18+
return logpdf_forward(flow, θ[6:7])
19+
end
20+
test_ad(randn(11)) do θ
21+
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
22+
flow = transformed(MvNormal(2, 1), inv(layer))
23+
return sum(logpdf_forward(flow, reshape(θ[6:end], 2, :)))
24+
end
25+
end

test/bijectors/utils.jl

+10-22
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,10 @@ function test_bijector(
131131
test_bijector_reals(b, x_true, y_true, logjac_true; kwargs...)
132132

133133
# Test AD
134-
if isclosedform(b)
135-
test_ad(x -> b(first(x)), [x_true, ])
136-
end
134+
test_ad(x -> b(first(x)), [x_true, ])
137135

138-
if isclosedform(ib)
139-
y = b(x_true)
140-
test_ad(x -> ib(first(x)), [y, ])
141-
end
136+
y = b(x_true)
137+
test_ad(x -> ib(first(x)), [y, ])
142138

143139
test_ad(x -> logabsdetjac(b, first(x)), [x_true, ])
144140
end
@@ -167,28 +163,20 @@ function test_bijector(
167163
test_bijector_arrays(b, collect(x_true), collect(y_true), logjac_true; kwargs...)
168164

169165
# Test AD
170-
if isclosedform(b)
171-
test_ad(x -> sum(b(x)), collect(x_true))
172-
end
173-
if isclosedform(ib)
174-
y = b(x_true)
175-
test_ad(x -> sum(ib(x)), y)
176-
end
166+
test_ad(x -> sum(b(x)), collect(x_true))
167+
y = b(x_true)
168+
test_ad(x -> sum(ib(x)), y)
177169

178170
test_ad(x -> logabsdetjac(b, x), x_true)
179171
end
180172
end
181173

182174
function test_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix; tol=1e-6)
183-
if isclosedform(b)
184-
logjac_ad = [logabsdet(ForwardDiff.jacobian(b, x))[1] for x in eachcol(xs)]
185-
@test mean(logabsdetjac(b, xs) - logjac_ad) tol
186-
end
175+
logjac_ad = [logabsdet(ForwardDiff.jacobian(b, x))[1] for x in eachcol(xs)]
176+
@test mean(logabsdetjac(b, xs) - logjac_ad) tol
187177
end
188178

189179
function test_logabsdetjac(b::Bijector{0}, xs::AbstractVector; tol=1e-6)
190-
if isclosedform(b)
191-
logjac_ad = [log(abs(ForwardDiff.derivative(b, x))) for x in xs]
192-
@test mean(logabsdetjac(b, xs) - logjac_ad) tol
193-
end
180+
logjac_ad = [log(abs(ForwardDiff.derivative(b, x))) for x in xs]
181+
@test mean(logabsdetjac(b, xs) - logjac_ad) tol
194182
end

test/norm_flows.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ end
2727
our_method = sum(forward(flow, z).logabsdetjac)
2828

2929
@test our_method forward_diff
30-
@test inv(flow)(flow(z)) z rtol=0.25
31-
@test (inv(flow) flow)(z) z rtol=0.25
30+
@test inv(flow)(flow(z)) z
31+
@test (inv(flow) flow)(z) z
3232
end
3333

3434
w = ones(10)

test/runtests.jl

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Bijectors
22

3+
using ChainRulesTestUtils
34
using Combinatorics
45
using DistributionsAD
56
using FiniteDifferences
@@ -35,6 +36,8 @@ if GROUP == "All" || GROUP == "Interface"
3536
end
3637

3738
if GROUP == "All" || GROUP == "AD"
39+
include("ad/chainrules.jl")
40+
include("ad/flows.jl")
3841
include("ad/distributions.jl")
3942
end
4043

0 commit comments

Comments
 (0)