Skip to content

Commit b204712

Browse files
oschulzdevmotion
andauthored
Use ChangesOfVariables and InverseFunctions (#212)
* Add ChangesOfVariables and InverseFunctions to deps * Replace forward by with_logabsdet_jacobian * Replace Base.inv with InverseFunctions.inverse * Improve deprecation scheme for forward Co-authored-by: David Widmann <[email protected]> * Improve deprecation scheme for inv * Test forward and inv deprecations * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Fixes regarding with_logabsdet_jacobian and inverse * Fix with_logabsdet_jacobian for NamedComposition * Fix deprecation of inv * Use inverse instead of inv for Composed * Use with_logabsdet_jacobian instead of forward * Workaround for intermittent failures in Dirichlet test * Use with_logabsdet_jacobian instead of forward * Use with_logabsdet_jacobian instead of forward * Add rrules for combine with PartitionMask Zygote-generated pullback for `combine(m::PartitionMask, x_1, x_2, x_3)` fails with `no method matching zero(::Type{Nothing})`. * Use inv instead of inverse for numbers * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Whitespace fix. Co-authored-by: David Widmann <[email protected]> * Move combine rrule and add test * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Use @test_deprecated Co-authored-by: David Widmann <[email protected]> * Use @test_deprecated Co-authored-by: David Widmann <[email protected]> * Use inverse instead of inv * Use test_inverse and test_with_logabsdet_jacobian * Use inverse instead of inv * Increase version number to v0.9.12 * Reexport with_logabsdet_jacobian and inverse * Increase package version to v0.10.0 Co-authored-by: David Widmann <[email protected]>
1 parent 31b1c38 commit b204712

33 files changed

+444
-378
lines changed

Project.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.9.11"
3+
version = "0.10.0"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
89
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
910
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1011
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
12+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
1113
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1214
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1315
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
@@ -22,9 +24,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2224
[compat]
2325
ArgCheck = "1, 2"
2426
ChainRulesCore = "0.10.11, 1"
27+
ChangesOfVariables = "0.1"
2528
Compat = "3"
2629
Distributions = "0.23.3, 0.24, 0.25"
2730
Functors = "0.1, 0.2"
31+
InverseFunctions = "0.1"
2832
IrrationalConstants = "0.1"
2933
LogExpFunctions = "0.3.3"
3034
MappedArrays = "0.2.2, 0.3, 0.4"

README.md

+40-38
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ The following table lists mathematical operations for a bijector and the corresp
1818

1919
| Operation | Method | Automatic |
2020
|:------------------------------------:|:-----------------:|:-----------:|
21-
| `b ↦ b⁻¹` | `inv(b)` ||
21+
| `b ↦ b⁻¹` | `inverse(b)` ||
2222
| `(b₁, b₂) ↦ (b₁ ∘ b₂)` | `b₁ ∘ b₂` ||
2323
| `(b₁, b₂) ↦ [b₁, b₂]` | `stack(b₁, b₂)` ||
2424
| `x ↦ b(x)` | `b(x)` | × |
25-
| `y ↦ b⁻¹(y)` | `inv(b)(y)` | × |
25+
| `y ↦ b⁻¹(y)` | `inverse(b)(y)` | × |
2626
| `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD |
27-
| `x ↦ b(x), log|det J(b, x)|` | `forward(b, x)` ||
27+
| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` ||
2828
| `p ↦ q := b_* p` | `q = transformed(p, b)` ||
2929
| `y ∼ q` | `y = rand(q)` ||
3030
| `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` ||
@@ -123,7 +123,7 @@ true
123123
What about `invlink`?
124124

125125
```julia
126-
julia> b⁻¹ = inv(b)
126+
julia> b⁻¹ = inverse(b)
127127
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))
128128

129129
julia> b⁻¹(y)
@@ -133,7 +133,7 @@ julia> b⁻¹(y) ≈ invlink(dist, y)
133133
true
134134
```
135135

136-
Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inv(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inv(Exp()) isa Log` is true.
136+
Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inverse(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inverse(Exp()) isa Log` is true.
137137

138138
#### Dimensionality
139139
One more thing. See the `0` in `Inverse{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`:
@@ -162,7 +162,7 @@ true
162162
And since `Composed isa Bijector`:
163163

164164
```julia
165-
julia> id_x = inv(id_y)
165+
julia> id_x = inverse(id_y)
166166
Composed{Tuple{Inverse{Logit{Float64},0},Logit{Float64}},0}((Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0)))
167167

168168
julia> id_x(x) x
@@ -199,9 +199,9 @@ julia> logpdf_forward(td, x)
199199
-1.123311289915276
200200
```
201201

202-
#### `logabsdetjac` and `forward`
202+
#### `logabsdetjac` and `with_logabsdet_jacobian`
203203

204-
In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inv(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method
204+
In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inverse(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method
205205

206206
```julia
207207
julia> logabsdetjac(b⁻¹, y)
@@ -218,21 +218,21 @@ julia> logabsdetjac(b, x) ≈ -logabsdetjac(b⁻¹, y)
218218
true
219219
```
220220

221-
which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `forward` comes to good use:
221+
which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `with_logabsdet_jacobian` comes to good use:
222222

223223
```julia
224-
julia> forward(b, x)
225-
(rv = -0.5369949942509267, logabsdetjac = 1.4575353795716655)
224+
julia> with_logabsdet_jacobian(b, x)
225+
(-0.5369949942509267, 1.4575353795716655)
226226
```
227227

228228
Similarily
229229

230230
```julia
231-
julia> forward(inv(b), y)
232-
(rv = 0.3688868996596376, logabsdetjac = -1.4575353795716655)
231+
julia> with_logabsdet_jacobian(inverse(b), y)
232+
(0.3688868996596376, -1.4575353795716655)
233233
```
234234

235-
In fact, the purpose of `forward` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `forward(b, x)` will take advantage of such opportunities (if implemented).
235+
In fact, the purpose of `with_logabsdet_jacobian` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `with_logabsdet_jacobian(b, x)` will take advantage of such opportunities (if implemented).
236236

237237
#### Sampling from `TransformedDistribution`
238238
At this point we've only shown that we can replicate the existing functionality. But we said `TransformedDistribution isa Distribution`, so we also have `rand`:
@@ -241,7 +241,7 @@ At this point we've only shown that we can replicate the existing functionality.
241241
julia> y = rand(td) # ∈ ℝ
242242
0.999166054552483
243243

244-
julia> x = inv(td.transform)(y) # transform back to interval [0, 1]
244+
julia> x = inverse(td.transform)(y) # transform back to interval [0, 1]
245245
0.7308945834125756
246246
```
247247

@@ -261,7 +261,7 @@ Beta{Float64}(α=2.0, β=2.0)
261261
julia> b = bijector(dist) # (0, 1) → ℝ
262262
Logit{Float64}(0.0, 1.0)
263263

264-
julia> b⁻¹ = inv(b) # ℝ → (0, 1)
264+
julia> b⁻¹ = inverse(b) # ℝ → (0, 1)
265265
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))
266266

267267
julia> td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1) then b(x) ∈ (0, 1)
@@ -280,7 +280,7 @@ It's worth noting that `support(Beta)` is the _closed_ interval `[0, 1]`, while
280280
```julia
281281
td = transformed(Beta())
282282

283-
inv(td.transform)(rand(td))
283+
inverse(td.transform)(rand(td))
284284
```
285285

286286
will never result in `0` or `1` though any sample arbitrarily close to either `0` or `1` is possible. _Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky._
@@ -335,7 +335,7 @@ julia> # Construct the transform
335335
bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists
336336
(Logit{Float64}(0.0, 1.0), Log{0}(), SimplexBijector{true}())
337337

338-
julia> ibs = inv.(bs) # invert, so we get unconstrained-to-constrained
338+
julia> ibs = inverse.(bs) # invert, so we get unconstrained-to-constrained
339339
(Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inverse{SimplexBijector{true},1}(SimplexBijector{true}()))
340340

341341
julia> sb = Stacked(ibs, ranges) # => Stacked <: Bijector
@@ -411,7 +411,7 @@ Similarily to the multivariate ADVI example, we could use `Stacked` to get a _bo
411411
```julia
412412
julia> d = MvNormal(zeros(2), ones(2));
413413

414-
julia> ibs = inv.(bijector.((InverseGamma(2, 3), Beta())));
414+
julia> ibs = inverse.(bijector.((InverseGamma(2, 3), Beta())));
415415

416416
julia> sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)]
417417
Stacked{Tuple{Exp{0},Inverse{Logit{Float64},0}},2}((Exp{0}(), Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))
@@ -481,7 +481,7 @@ julia> Flux.params(flow)
481481
Params([[-1.05099; 0.502079] (tracked), [-0.216248; -0.706424] (tracked), [-4.33747] (tracked)])
482482
```
483483

484-
Another useful function is the `forward(d::Distribution)` method. It is similar to `forward(b::Bijector)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path.
484+
Another useful function is the `forward(d::Distribution)` method. It is similar to `with_logabsdet_jacobian(b::Bijector, x)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path.
485485

486486
```julia
487487
julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns all the useful quantities in one pass
@@ -542,41 +542,43 @@ Logit{Float64}(0.0, 1.0)
542542
julia> b(0.6)
543543
0.4054651081081642
544544

545-
julia> inv(b)(y)
545+
julia> inverse(b)(y)
546546
Tracked 2-element Array{Float64,1}:
547547
0.3078149833748082
548548
0.72380041667891
549549

550550
julia> logabsdetjac(b, 0.6)
551551
1.4271163556401458
552552

553-
julia> logabsdetjac(inv(b), y) # defaults to `- logabsdetjac(b, inv(b)(x))`
553+
julia> logabsdetjac(inverse(b), y) # defaults to `- logabsdetjac(b, inverse(b)(x))`
554554
Tracked 2-element Array{Float64,1}:
555555
-1.546158373866469
556556
-1.6098711387913573
557557

558-
julia> forward(b, 0.6) # defaults to `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`
559-
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
558+
julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetjac(b, x))`
559+
(0.4054651081081642, 1.4271163556401458)
560560
```
561561

562-
For further efficiency, one could manually implement `forward(b::Logit, x)`:
562+
For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`:
563563

564564
```julia
565-
julia> import Bijectors: forward, Logit
565+
julia> using Bijectors: Logit
566566

567-
julia> function forward(b::Logit{<:Real}, x)
567+
julia> import Bijectors: with_logabsdet_jacobian
568+
569+
julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x)
568570
totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not
569571
y = logit.(totally_worth_saving)
570572
logjac = @. - log((b.b - x) * totally_worth_saving)
571-
return (rv=y, logabsdetjac = logjac)
573+
return (y, logjac)
572574
end
573575
forward (generic function with 16 methods)
574576

575-
julia> forward(b, 0.6)
576-
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
577+
julia> with_logabsdet_jacobian(b, 0.6)
578+
(0.4054651081081642, 1.4271163556401458)
577579

578-
julia> @which forward(b, 0.6)
579-
forward(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
580+
julia> @which with_logabsdet_jacobian(b, 0.6)
581+
with_logabsdet_jacobian(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
580582
```
581583
582584
As you can see it's a very contrived example, but you get the idea.
@@ -613,10 +615,10 @@ julia> logabsdetjac(b_ad, 0.6)
613615
julia> y = b_ad(0.6)
614616
0.4054651081081642
615617

616-
julia> inv(b_ad)(y)
618+
julia> inverse(b_ad)(y)
617619
0.6
618620

619-
julia> logabsdetjac(inv(b_ad), y)
621+
julia> logabsdetjac(inverse(b_ad), y)
620622
-1.4271163556401458
621623
```
622624
@@ -665,7 +667,7 @@ help?> Bijectors.Composed
665667

666668
A Bijector representing composition of bijectors. composel and composer results in a Composed for which application occurs from left-to-right and right-to-left, respectively.
667669

668-
Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methdos, e.g. inv.
670+
Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methods, e.g. inverse.
669671

670672
If you want to use an Array as the container instead you can do
671673

@@ -713,9 +715,9 @@ The distribution interface consists of:
713715
#### Methods
714716
The following methods are implemented by all subtypes of `Bijector`, this also includes bijectors such as `Composed`.
715717
- `(b::Bijector)(x)`: implements the transform of the `Bijector`
716-
- `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
718+
- `inverse(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
717719
- `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))).
718-
- `forward(b::Bijector, x)`: returns named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` in the most efficient manner.
720+
- `with_logabsdet_jacobian(b::Bijector, x)`: returns the tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner.
719721
- `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation.
720722
- `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency.
721723
- `dimension(b::Bijector)`: returns the dimensionality of `b`.
@@ -725,7 +727,7 @@ For `TransformedDistribution`, together with default implementations for `Distri
725727
- `bijector(d::Distribution)`: returns the default constrained-to-unconstrained bijector for `d`
726728
- `transformed(d::Distribution)`, `transformed(d::Distribution, b::Bijector)`: constructs a `TransformedDistribution` from `d` and `b`.
727729
- `logpdf_forward(d::Distribution, x)`, `logpdf_forward(d::Distribution, x, logjac)`: computes the `logpdf(td, td.transform(x))` using the forward pass, which is potentially faster depending on the transform at hand.
728-
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inv(b), b(x))` depending on which is most efficient.
730+
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inverse(b), b(x))` depending on which is most efficient.
729731
730732
# Bibliography
731733
1. Rezende, D. J., & Mohamed, S. (2015). Variational Inference With Normalizing Flows. [arXiv:1505.05770](https://arxiv.org/abs/1505.05770v6).

src/Bijectors.jl

+15-3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ using MappedArrays
3535
using Base.Iterators: drop
3636
using LinearAlgebra: AbstractTriangular
3737

38+
import ChangesOfVariables: with_logabsdet_jacobian
39+
import InverseFunctions: inverse
40+
3841
import ChainRulesCore
3942
import Functors
4043
import IrrationalConstants
@@ -51,6 +54,8 @@ export TransformDistribution,
5154
logpdf_with_trans,
5255
isclosedform,
5356
transform,
57+
with_logabsdet_jacobian,
58+
inverse,
5459
forward,
5560
logabsdetjac,
5661
logabsdetjacinv,
@@ -121,7 +126,7 @@ end
121126
# Distributions
122127

123128
link(d::Distribution, x) = bijector(d)(x)
124-
invlink(d::Distribution, y) = inv(bijector(d))(y)
129+
invlink(d::Distribution, y) = inverse(bijector(d))(y)
125130
function logpdf_with_trans(d::Distribution, x, transform::Bool)
126131
if ispd(d)
127132
return pd_logpdf_with_trans(d, x, transform)
@@ -188,14 +193,14 @@ function invlink(
188193
y::AbstractVecOrMat{<:Real},
189194
::Val{proj}=Val(true),
190195
) where {proj}
191-
return inv(SimplexBijector{proj}())(y)
196+
return inverse(SimplexBijector{proj}())(y)
192197
end
193198
function invlink_jacobian(
194199
d::Dirichlet,
195200
y::AbstractVector{<:Real},
196201
::Val{proj}=Val(true),
197202
) where {proj}
198-
return jacobian(inv(SimplexBijector{proj}()), y)
203+
return jacobian(inverse(SimplexBijector{proj}()), y)
199204
end
200205

201206
## Matrix
@@ -249,6 +254,13 @@ include("utils.jl")
249254
include("interface.jl")
250255
include("chainrules.jl")
251256

257+
Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x))
258+
259+
@noinline function Base.inv(b::AbstractBijector)
260+
Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv)
261+
inverse(b)
262+
end
263+
252264
# Broadcasting here breaks Tracker for some reason
253265
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
254266
maporbroadcast(f, x::AbstractArray...) = f.(x...)

0 commit comments

Comments
 (0)