Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Variable transformations #582

Merged
merged 12 commits into from
Feb 3, 2025
7 changes: 7 additions & 0 deletions _quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ website:
- developers/compiler/minituring-contexts/index.qmd
- developers/compiler/design-overview/index.qmd

- section: "Variable Transformations"
collapse-level: 1
contents:
- developers/transforms/distributions/index.qmd
- developers/transforms/bijectors/index.qmd
- developers/transforms/dynamicppl/index.qmd

- section: "Inference (note: outdated)"
collapse-level: 1
contents:
Expand Down
311 changes: 311 additions & 0 deletions developers/transforms/bijectors/index.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
---
title: "Bijectors in MCMC"
engine: julia
---

```{julia}
#| echo: false
#| output: false
using Pkg;
Pkg.instantiate();
```

All the above has purely been a mathematical discussion of how distributions can be transformed.
Now, we turn to their implementation in Julia, specifically using the [Bijectors.jl package](https://github.com/TuringLang/Bijectors.jl).

## Bijectors.jl

```{julia}
import Random
Random.seed!(468);

using Distributions: Normal, LogNormal, logpdf
using Statistics: mean, var
using Plots: histogram
```

A _bijection_ between two sets ([Wikipedia](https://en.wikipedia.org/wiki/Bijection)) is, essentially, a one-to-one mapping between the elements of these sets.
That is to say, if we have two sets $X$ and $Y$, then a bijection maps each element of $X$ to a unique element of $Y$.
To return to our univariate example, where we transformed $x$ to $y$ using $y = \exp(x)$, the exponentiation function is a bijection because every value of $x$ maps to one unique value of $y$.
The input set (the domain) is $(-\infty, \infty)$, and the output set (the codomain) is $(0, \infty)$.

Since bijections are a one-to-one mapping between elements, we can also reverse the direction of this mapping to create an inverse function.
In the case of $y = \exp(x)$, the inverse function is $x = \log(y)$.

::: {.callout-note}
Technically, the bijections in Bijectors.jl are functions $f: X \to Y$ for which:

- $f$ is continuously differentiable, i.e. the derivative $\mathrm{d}f(x)/\mathrm{d}x$ exists and is continuous (over the domain of interest $X$);
- If $f^{-1}: Y \to X$ is the inverse of $f$, then that is also continuously differentiable (over _its_ own domain, i.e. $Y$).

The technical mathematical term for this is a diffeomorphism ([Wikipedia](https://en.wikipedia.org/wiki/Diffeomorphism)), but we call them 'bijectors'.

When thinking about continuous differentiability, it's important to be conscious of the domains or codomains that we care about.
For example, taking the inverse function $\log(y)$ from above, its derivative is $1/y$, which is not continuous at $y = 0$.
However, we specified that the bijection $y = \exp(x)$ maps values of $x \in (-\infty, \infty)$ to $y \in (0, \infty)$, so the point $y = 0$ is not within the domain of the inverse function.
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
:::

Specifically, one of the primary purposes of Bijectors.jl is used to construct _bijections which map constrained distributions to unconstrained ones_.
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
For example, the log-normal distribution which we saw above is constrained: its _support_, i.e. the range over which $p(x) > 0$, is $(0, \infty)$.
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
However, we can transform that to an unconstrained distribution (the normal distribution) using the transformation $y = \log(x)$.

::: {.callout-note}
Bijectors.jl, as well as DynamicPPL (which we'll come to later), can work with a much broader class of bijective transformations of variables, not just ones that go to the entire real line.
But for the purposes of MCMC, unconstraining is the most common transformation, so we'll stick with that terminology.
:::


The `bijector` function, when applied to a distribution, returns a bijection $f$ that can be used to map the constrained distribution to an unconstrained one.
Unsurprisingly, for the log-normal distribution, the bijection is (a broadcasted version of) the $\log$ function.

```{julia}
import Bijectors as B

f = B.bijector(LogNormal())
```

We can apply this transformation to samples from the original distribution, for example:

```{julia}
samples_lognormal = rand(LogNormal(), 5)

samples_normal = f(samples_lognormal)
```

We can also obtain the inverse of a bijection, $f^{-1}$:

```{julia}
f_inv = B.inverse(f)

f_inv(samples_normal) == samples_lognormal
```

We know that the transformation $y = \log(x)$ changes the log-normal distribution to the normal distribution.
Bijectors.jl also gives us a way to access that transformed distribution:

```{julia}
transformed_dist = B.transformed(LogNormal(), f)
```

This type doesn't immediately look like a `Normal()`, but it behaves in exactly the same way.
For example, we can sample from it and plot a histogram:

```{julia}
samples_plot = rand(transformed_dist, 5000)
histogram(samples_plot, bins=50)
```

We can also obtain the logpdf of the transformed distribution and check that it is the same as that of a normal distribution:

```{julia}
println("Sample: $(samples_plot[1])")
println("Expected: $(logpdf(Normal(), samples_plot[1]))")
println("Actual: $(logpdf(transformed_dist, samples_plot[1]))")
```

Given the discussion in the previous sections, you might not be surprised to find that the logpdf of the transformed distribution is implemented using the Jacobian of the transformation.
In particular, it [directly uses](https://github.com/TuringLang/Bijectors.jl/blob/f52a9c52ede1f43155239447601387eb1dafe394/src/Bijectors.jl#L242-L255) the formula

$$\log(q(\mathbf{y})) = \log(p(\mathbf{x})) - \log(|\det(\mathbf{J})|).$$

You can access $\log(|\det(\mathbf{J})|)$ (evaluated at the point $\mathbf{x}$) using the `logabsdetjac` function:

```{julia}
# Reiterating the setup, just to be clear
original_dist = LogNormal()
x = rand(original_dist)
f = B.bijector(original_dist)
y = f(x)
transformed_dist = B.transformed(LogNormal(), f)

println("log(q(y)) : $(logpdf(transformed_dist, y))")
println("log(p(x)) : $(logpdf(original_dist, x))")
println("log(|det(J)|) : $(B.logabsdetjac(f, x))")
```

from which you can see that the equation above holds.
There are more functions available in the Bijectors.jl API; for full details do check out the [documentation](https://turinglang.org/Bijectors.jl/stable/).
For example, `logpdf_with_trans` can directly give us $\log(q(\mathbf{y}))$ without going through the effort of constructing the bijector:

```{julia}
B.logpdf_with_trans(original_dist, x, true)
```

## The case for bijectors in MCMC

Constraints pose a challenge for many numerical methods such as optimisation, and sampling is no exception to this.
The problem is that for any value $x$ outside of the support of a constrained distribution, $p(x)$ will be zero, and the logpdf will be $-\infty$.
Thus, any term that involves some ratio of probabilities (or equivalently, the logpdf) will be infinite.

### Metropolis with rejection

To see the practical impact of this on sampling, let's attempt to sample from a log-normal distribution using a random walk Metropolis algorithm.

One way of handling constraints is to simply reject any steps that would take us out of bounds.
This is a barebones implementation which does precisely that:

```{julia}
# Take a step where the proposal is a normal distribution centred around
# the current value. Return the new value, plus a flag to indicate whether
# the new value was in bounds.
function mh_step(logp, x, in_bounds)
x_proposed = rand(Normal(x, 1))
in_bounds(x_proposed) || return (x, false) # bounds check
acceptance_logp = logp(x_proposed) - logp(x)
return if log(rand()) < acceptance_logp
(x_proposed, true) # successful step
else
(x, true) # failed step
end
end

# Run a random walk Metropolis sampler.
# `logp` : a function that takes `x` and returns the log pdf of the
# distribution we're trying to sample from (up to a constant
# additive factor)
# `n_samples` : the number of samples to draw
# `in_bounds` : a function that takes `x` and returns whether `x` is within
# the support of the distribution
# `x0` : the initial value
# Returns a vector of samples, plus the number of times we went out of bounds.
function mh(logp, n_samples, in_bounds; x0=1.0)
samples = []
x = x0
n_out_of_bounds = 0
for _ in 2:n_samples
x, inb = mh_step(logp, x, in_bounds)
if !inb
n_out_of_bounds += 1
end
push!(samples, x)
end
return (samples, n_out_of_bounds)
end
```

::: {.callout-note}
In the MH algorithm, we technically do not need to explicitly check the proposal, because for any $x \leq 0$, we have that $p(x) = 0$; thus, the acceptance probability will be zero.
However, doing so here allows us to track how often this happens, and also illustrates the general principle of handling constraints by rejection.
:::

Now to actually perform the sampling:

```{julia}
logp(x) = logpdf(LogNormal(), x)
samples, n_out_of_bounds = mh(logp, 10000, x -> x > 0)
histogram(samples, bins=0:0.1:5; xlims=(0, 5))
```

How do we know that this has sampled correctly?
For one, we can check that the mean of the samples are what we expect them to be.
From [Wikipedia](https://en.wikipedia.org/wiki/Log-normal_distribution), the mean of a log-normal distribution is given by $\exp[\mu + (\sigma^2/2)]$.
For our log-normal distribution, we set $\mu = 0$ and $\sigma = 1$, so:

```{julia}
println("expected mean: $(exp(0 + (1^2/2)))")
println(" actual mean: $(mean(samples))")
```

### Metropolis with transformation

The issue with this is that many of the sampling steps are unproductive, in that they bring us to the region of $x \leq 0$ and get rejected:

```{julia}
println("went out of bounds $n_out_of_bounds/10000 times")
```

And this could have been even worse if we had chosen a wider proposal distribution in the Metropolis step, or if the support of the distribution was narrower!
In general, we probably don't want to have to re-parameterise our proposal distribution each time we sample from a distribution with different constraints.

This is where the transformation functions from Bijectors.jl come in: we can use them to map the distribution to an unconstrained one and sample from *that* instead.
Since the sampler only ever sees an unconstrained distribution, it doesn't have to worry about checking for bounds.

To make this happen, instead of passing $\log(p(x))$ to the sampler, we pass $\log(q(y))$.
This can be obtained using the `Bijectors.logpdf_with_trans` function that was introduced above.

```{julia}
d = LogNormal()
f = B.bijector(d) # Transformation function
f_inv = B.inverse(f) # Inverse transformation function
function logq(y)
x = f_inv(y)
return B.logpdf_with_trans(d, x, true)
end
samples_transformed, _ = mh(logq, 10000, x -> true);
```

Now, this process gives us samples that have been transformed, so we need to un-transform them to get the samples from the original distribution:

```{julia}
samples_untransformed = f_inv(samples_transformed)
histogram(samples_untransformed, bins=0:0.1:5; xlims=(0, 5))
```

We can check the mean of the samples too, to see that it is what we expect:

```{julia}
println("expected mean: $(exp(0 + (1^2/2)))")
println(" actual mean: $(mean(samples_untransformed))")
```

penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
### Which one is better?

In the subsections above, we've seen two different methods of sampling from a constrained distribution:

1. Sample directly from the distribution and reject any samples outside of its support.
2. Transform the distribution to an unconstrained one and sample from that instead.

(Note that both of these methods are applicable to other samplers as well, such as Hamiltonian Monte Carlo.)

Of course, the question then becomes which one of these is better.
We might look at the sample means above to see which one is 'closer' to the expected mean, but that's not a very robust method because the sample mean is itself random.
What we could do is to perform both methods many times and see how reliable the sample mean is.

```{julia}
function get_sample_mean(; transform)
if transform
# Sample from transformed distribution
samples = f_inv(first(mh(logq, 10000, x -> true)))
else
# Sample from original distribution and reject if out of bounds
samples = first(mh(logp, 10000, x -> x > 0))
end
return mean(samples)
end
```

```{julia}
means_with_rejection = [get_sample_mean(; transform=false) for _ in 1:1000]
mean(means_with_rejection), var(means_with_rejection)
```

```{julia}
means_with_transformation = [get_sample_mean(; transform=true) for _ in 1:1000]
mean(means_with_transformation), var(means_with_transformation)
```

We can see from this small study that although both methods give us the correct mean (on average), the method with the transformation is more reliable, in that the variance is much lower!

### What happens without the Jacobian?

In the transformation method above, we used `Bijectors.logpdf_with_trans` to calculate the log probability density of the transformed distribution.
This function makes sure to include the Jacobian term when performing the transformation, and this is what makes sure that when we un-transform the samples, we get the correct distribution.

The next code block shows what happens if we don't include the Jacobian term.
In this `logq_wrong`, we've un-transformed `y` to `x` and calculated the logpdf with respect to its original distribution.
This is exactly the same mistake that we made at the start of this article with `naive_logpdf`.

```{julia}
function logq_wrong(y)
x = f_inv(y)
return logpdf(d, x) # no Jacobian term!
end
samples_questionable, _ = mh(logq_wrong, 100000, x -> true)
samples_questionable_untransformed = f_inv(samples_questionable)

println("mean: $(mean(samples_questionable_untransformed))")
```

You can see that even though we used ten times more samples, the mean is quite wrong, which implies that our samples are not being drawn from the correct distribution.

In the next page, we'll see how to use these transformations in the context of a probabilistic programming language, paying particular attention to their handling in DynamicPPL.
Loading