Skip to content

Commit

Permalink
Variable transformations (#582)
Browse files Browse the repository at this point in the history
* Add transforms docs

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Fix missing imports

* Add missing images

* Update Manifest

* Re-seed sampling

* Minor text fixes

* Changes as per @willtebbutt's suggestions

Co-authored-by: Will Tebbutt <[email protected]>

* Extra details about rejection vs transformation

* Changes per @mhauru's comments

Co-authored-by: Markus Hauru <[email protected]>

* Fix Mermaid SVG CSS

* Add a concluding paragraph

---------

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Will Tebbutt <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
  • Loading branch information
4 people authored Feb 3, 2025
1 parent 523428a commit 2a44076
Show file tree
Hide file tree
Showing 7 changed files with 1,091 additions and 18 deletions.
10 changes: 10 additions & 0 deletions _quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ website:
- developers/compiler/minituring-contexts/index.qmd
- developers/compiler/design-overview/index.qmd

- section: "Variable Transformations"
collapse-level: 1
contents:
- developers/transforms/distributions/index.qmd
- developers/transforms/bijectors/index.qmd
- developers/transforms/dynamicppl/index.qmd

- section: "Inference (note: outdated)"
collapse-level: 1
contents:
Expand Down Expand Up @@ -195,3 +202,6 @@ using-turing-abstractmcmc: developers/inference/abstractmcmc-turing
using-turing-interface: developers/inference/abstractmcmc-interface
using-turing-variational-inference: developers/inference/variational-inference
using-turing-implementing-samplers: developers/inference/implementing-samplers
dev-transforms-distributions: developers/transforms/distributions
dev-transforms-bijectors: developers/transforms/bijectors
dev-transforms-dynamicppl: developers/transforms/dynamicppl
340 changes: 340 additions & 0 deletions developers/transforms/bijectors/index.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,340 @@
---
title: "Bijectors in MCMC"
engine: julia
---

```{julia}
#| echo: false
#| output: false
using Pkg;
Pkg.instantiate();
```

All the above has purely been a mathematical discussion of how distributions can be transformed.
Now, we turn to their implementation in Julia, specifically using the [Bijectors.jl package](https://github.com/TuringLang/Bijectors.jl).

## Bijectors.jl

```{julia}
import Random
Random.seed!(468);
using Distributions: Normal, LogNormal, logpdf
using Statistics: mean, var
using Plots: histogram
```

A _bijection_ between two sets ([Wikipedia](https://en.wikipedia.org/wiki/Bijection)) is, essentially, a one-to-one mapping between the elements of these sets.
That is to say, if we have two sets $X$ and $Y$, then a bijection maps each element of $X$ to a unique element of $Y$.
To return to our univariate example, where we transformed $x$ to $y$ using $y = \exp(x)$, the exponentiation function is a bijection because every value of $x$ maps to one unique value of $y$.
The input set (the domain) is $(-\infty, \infty)$, and the output set (the codomain) is $(0, \infty)$.
(Here, $(a, b)$ denotes the open interval from $a$ to $b$ but excluding $a$ and $b$ themselves.)

Since bijections are a one-to-one mapping between elements, we can also reverse the direction of this mapping to create an inverse function.
In the case of $y = \exp(x)$, the inverse function is $x = \log(y)$.

::: {.callout-note}
Technically, the bijections in Bijectors.jl are functions $f: X \to Y$ for which:

- $f$ is continuously differentiable, i.e. the derivative $\mathrm{d}f(x)/\mathrm{d}x$ exists and is continuous (over the domain of interest $X$);
- If $f^{-1}: Y \to X$ is the inverse of $f$, then that is also continuously differentiable (over _its_ own domain, i.e. $Y$).

The technical mathematical term for this is a diffeomorphism ([Wikipedia](https://en.wikipedia.org/wiki/Diffeomorphism)), but we call them 'bijectors'.

When thinking about continuous differentiability, it's important to be conscious of the domains or codomains that we care about.
For example, taking the inverse function $\log(y)$ from above, its derivative is $1/y$, which is not continuous at $y = 0$.
However, we specified that the bijection $y = \exp(x)$ maps values of $x \in (-\infty, \infty)$ to $y \in (0, \infty)$, so the point $y = 0$ is not within the domain of the inverse function.
:::

Specifically, one of the primary purposes of Bijectors.jl is to construct _bijections which map constrained distributions to unconstrained ones_.
For example, the log-normal distribution which we saw in [the previous page]({{< meta dev-transforms-distributions >}}) is constrained: its _support_, i.e. the range over which $p(x) > 0$, is $(0, \infty)$.
However, we can transform that to an unconstrained distribution (the normal distribution) using the transformation $y = \log(x)$.

::: {.callout-note}
Bijectors.jl, as well as DynamicPPL (which we'll come to later), can work with a much broader class of bijective transformations of variables, not just ones that go to the entire real line.
But for the purposes of MCMC, unconstraining is the most common transformation, so we'll stick with that terminology.
:::


The `bijector` function, when applied to a distribution, returns a bijection $f$ that can be used to map the constrained distribution to an unconstrained one.
Unsurprisingly, for the log-normal distribution, the bijection is (a broadcasted version of) the $\log$ function.

```{julia}
import Bijectors as B
f = B.bijector(LogNormal())
```

We can apply this transformation to samples from the original distribution, for example:

```{julia}
samples_lognormal = rand(LogNormal(), 5)
samples_normal = f(samples_lognormal)
```

We can also obtain the inverse of a bijection, $f^{-1}$:

```{julia}
f_inv = B.inverse(f)
f_inv(samples_normal) == samples_lognormal
```

We know that the transformation $y = \log(x)$ changes the log-normal distribution to the normal distribution.
Bijectors.jl also gives us a way to access that transformed distribution:

```{julia}
transformed_dist = B.transformed(LogNormal(), f)
```

This type doesn't immediately look like a `Normal()`, but it behaves in exactly the same way.
For example, we can sample from it and plot a histogram:

```{julia}
samples_plot = rand(transformed_dist, 5000)
histogram(samples_plot, bins=50)
```

We can also obtain the logpdf of the transformed distribution and check that it is the same as that of a normal distribution:

```{julia}
println("Sample: $(samples_plot[1])")
println("Expected: $(logpdf(Normal(), samples_plot[1]))")
println("Actual: $(logpdf(transformed_dist, samples_plot[1]))")
```

Given the discussion in the previous sections, you might not be surprised to find that the logpdf of the transformed distribution is implemented using the Jacobian of the transformation.
In particular, it [directly uses](https://github.com/TuringLang/Bijectors.jl/blob/f52a9c52ede1f43155239447601387eb1dafe394/src/Bijectors.jl#L242-L255) the formula

$$\log(q(\mathbf{y})) = \log(p(\mathbf{x})) - \log(|\det(\mathbf{J})|).$$

You can access $\log(|\det(\mathbf{J})|)$ (evaluated at the point $\mathbf{x}$) using the `logabsdetjac` function:

```{julia}
# Reiterating the setup, just to be clear
original_dist = LogNormal()
x = rand(original_dist)
f = B.bijector(original_dist)
y = f(x)
transformed_dist = B.transformed(LogNormal(), f)
println("log(q(y)) : $(logpdf(transformed_dist, y))")
println("log(p(x)) : $(logpdf(original_dist, x))")
println("log(|det(J)|) : $(B.logabsdetjac(f, x))")
```

from which you can see that the equation above holds.
There are more functions available in the Bijectors.jl API; for full details do check out the [documentation](https://turinglang.org/Bijectors.jl/stable/).
For example, `logpdf_with_trans` can directly give us $\log(q(\mathbf{y}))$ without going through the effort of constructing the bijector:

```{julia}
B.logpdf_with_trans(original_dist, x, true)
```

## The case for bijectors in MCMC

Constraints pose a challenge for many numerical methods such as optimisation, and sampling is no exception to this.
The problem is that for any value $x$ outside of the support of a constrained distribution, $p(x)$ will be zero, and the logpdf will be $-\infty$.
Thus, any term that involves some ratio of probabilities (or equivalently, the logpdf) will be infinite.

### Metropolis with rejection

To see the practical impact of this on sampling, let's attempt to sample from a log-normal distribution using a random walk Metropolis algorithm.

One way of handling constraints is to simply reject any steps that would take us out of bounds.
This is a barebones implementation which does precisely that:

```{julia}
# Take a step where the proposal is a normal distribution centred around
# the current value. Return the new value, plus a flag to indicate whether
# the new value was in bounds.
function mh_step(logp, x, in_bounds)
x_proposed = rand(Normal(x, 1))
in_bounds(x_proposed) || return (x, false) # bounds check
acceptance_logp = logp(x_proposed) - logp(x)
return if log(rand()) < acceptance_logp
(x_proposed, true) # successful step
else
(x, true) # failed step
end
end
# Run a random walk Metropolis sampler.
# `logp` : a function that takes `x` and returns the log pdf of the
# distribution we're trying to sample from (up to a constant
# additive factor)
# `n_samples` : the number of samples to draw
# `in_bounds` : a function that takes `x` and returns whether `x` is within
# the support of the distribution
# `x0` : the initial value
# Returns a vector of samples, plus the number of times we went out of bounds.
function mh(logp, n_samples, in_bounds; x0=1.0)
samples = [x0]
x = x0
n_out_of_bounds = 0
for _ in 2:n_samples
x, inb = mh_step(logp, x, in_bounds)
if !inb
n_out_of_bounds += 1
end
push!(samples, x)
end
return (samples, n_out_of_bounds)
end
```

::: {.callout-note}
In the MH algorithm, we technically do not need to explicitly check the proposal, because for any $x \leq 0$, we have that $p(x) = 0$; thus, the acceptance probability will be zero.
However, doing so here allows us to track how often this happens, and also illustrates the general principle of handling constraints by rejection.
:::

Now to actually perform the sampling:

```{julia}
logp(x) = logpdf(LogNormal(), x)
samples, n_out_of_bounds = mh(logp, 10000, x -> x > 0)
histogram(samples, bins=0:0.1:5; xlims=(0, 5))
```

How do we know that this has sampled correctly?
For one, we can check that the mean of the samples are what we expect them to be.
From [Wikipedia](https://en.wikipedia.org/wiki/Log-normal_distribution), the mean of a log-normal distribution is given by $\exp[\mu + (\sigma^2/2)]$.
For our log-normal distribution, we set $\mu = 0$ and $\sigma = 1$, so:

```{julia}
println("expected mean: $(exp(0 + (1^2/2)))")
println(" actual mean: $(mean(samples))")
```

### Metropolis with transformation

The issue with this is that many of the sampling steps are unproductive, in that they bring us to the region of $x \leq 0$ and get rejected:

```{julia}
println("went out of bounds $n_out_of_bounds/10000 times")
```

And this could have been even worse if we had chosen a wider proposal distribution in the Metropolis step, or if the support of the distribution was narrower!
In general, we probably don't want to have to re-parameterise our proposal distribution each time we sample from a distribution with different constraints.

This is where the transformation functions from Bijectors.jl come in: we can use them to map the distribution to an unconstrained one and sample from *that* instead.
Since the sampler only ever sees an unconstrained distribution, it doesn't have to worry about checking for bounds.

To make this happen, instead of passing $\log(p(x))$ to the sampler, we pass $\log(q(y))$.
This can be obtained using the `Bijectors.logpdf_with_trans` function that was introduced above.

```{julia}
d = LogNormal()
f = B.bijector(d) # Transformation function
f_inv = B.inverse(f) # Inverse transformation function
function logq(y)
x = f_inv(y)
return B.logpdf_with_trans(d, x, true)
end
samples_transformed, n_oob_transformed = mh(logq, 10000, x -> true);
```

Now, this process gives us samples that have been transformed, so we need to un-transform them to get the samples from the original distribution:

```{julia}
samples_untransformed = f_inv(samples_transformed)
histogram(samples_untransformed, bins=0:0.1:5; xlims=(0, 5))
```

We can check the mean of the samples too, to see that it is what we expect:

```{julia}
println("expected mean: $(exp(0 + (1^2/2)))")
println(" actual mean: $(mean(samples_untransformed))")
```

On top of that, we can also verify that we don't ever go out of bounds:

```{julia}
println("went out of bounds $n_oob_transformed/10000 times")
```

### Which one is better?

In the subsections above, we've seen two different methods of sampling from a constrained distribution:

1. Sample directly from the distribution and reject any samples outside of its support.
2. Transform the distribution to an unconstrained one and sample from that instead.

(Note that both of these methods are applicable to other samplers as well, such as Hamiltonian Monte Carlo.)

Of course, a natural question to then ask is which one of these is better!

One option might be look at the sample means above to see which one is 'closer' to the expected mean.
However, that's not a very robust method because the sample mean is itself random, and if we were to use a different random seed we might well reach a different conclusion.

Another possibility we could look at the number of times the sample was rejected.
Does a lower rejection rate (as in the transformed case) imply that the method is better?
As it happens, this might seem like an intuitive conclusion, but it's not necessarily the case: for example, the sampling in unconstrained space could be much less efficient, such that even though we're not _rejecting_ samples, the ones that we do get are overly correlated and thus not representative of the distribution.

A robust comparison would involve performing both methods many times and seeing how _reliable_ the sample mean is.

```{julia}
function get_sample_mean(; transform)
if transform
# Sample from transformed distribution
samples = f_inv(first(mh(logq, 10000, x -> true)))
else
# Sample from original distribution and reject if out of bounds
samples = first(mh(logp, 10000, x -> x > 0))
end
return mean(samples)
end
```

```{julia}
means_with_rejection = [get_sample_mean(; transform=false) for _ in 1:1000]
mean(means_with_rejection), var(means_with_rejection)
```

```{julia}
means_with_transformation = [get_sample_mean(; transform=true) for _ in 1:1000]
mean(means_with_transformation), var(means_with_transformation)
```

We can see from this small study that although both methods give us the correct mean (on average), the method with the transformation is more reliable, in that the variance is much lower!

::: {.callout-note}
Alternatively, we could also try to directly measure how correlated the samples are.
One way to do this is to calculate the _effective sample size_ (ESS), which is described in [the Stan documentation](https://mc-stan.org/docs/reference-manual/analysis.html#effective-sample-size.section), and implemented in [MCMCChains.jl](https://github.com/TuringLang/MCMCChains.jl/).
A larger ESS implies that the samples are less correlated, and thus more representative of the underlying distribution:

```{julia}
using MCMCChains: Chains, ess
rejection = first(mh(logp, 10000, x -> x > 0))
transformation = f_inv(first(mh(logq, 10000, x -> true)))
chn = Chains(hcat(rejection, transformation), [:rejection, :transformation])
ess(chn)
```
:::

### What happens without the Jacobian?

In the transformation method above, we used `Bijectors.logpdf_with_trans` to calculate the log probability density of the transformed distribution.
This function makes sure to include the Jacobian term when performing the transformation, and this is what makes sure that when we un-transform the samples, we get the correct distribution.

The next code block shows what happens if we don't include the Jacobian term.
In this `logq_wrong`, we've un-transformed `y` to `x` and calculated the logpdf with respect to its original distribution.
This is exactly the same mistake that we made at the start of this article with `naive_logpdf`.

```{julia}
function logq_wrong(y)
x = f_inv(y)
return logpdf(d, x) # no Jacobian term!
end
samples_questionable, _ = mh(logq_wrong, 100000, x -> true)
samples_questionable_untransformed = f_inv(samples_questionable)
println("mean: $(mean(samples_questionable_untransformed))")
```

You can see that even though we used ten times more samples, the mean is quite wrong, which implies that our samples are not being drawn from the correct distribution.

In the next page, we'll see how to use these transformations in the context of a probabilistic programming language, paying particular attention to their handling in DynamicPPL.
Loading

0 comments on commit 2a44076

Please sign in to comment.