GPLVM tutorial #122
Merged
Changes from all commits (16 commits)
d8bce1c  initial model - advi issue (leachim)
f41cf3a  reproducible example for bug (leachim)
8256a94  issue with sparsity (leachim)
3debcec  basic flow of tutorial; requires clean-up and more text (leachim)
4157c64  mu (leachim)
1bb2e55  update dataset (leachim)
0e11fa3  change problem size for speed (leachim)
71e4d8f  updates to inference, speedup (leachim)
a88c5c5  debug build environment (leachim)
e9b1a72  working version (leachim)
1dcf2f9  mu (leachim)
b7c32c6  mu (leachim)
52fe171  mu (leachim)
a39bc31  Apply suggestions from code review (leachim)
d359c6a  Delete 12_gaussian-process-latent-variable-model.md (yebai)
6ecfe69  Remove generated markdown files. (yebai)
306 changes: 306 additions & 0 deletions
...s/12-gaussian-process-latent-variable-model/12_gaussian-process-latent-variable-model.jmd

@@ -0,0 +1,306 @@
---
title: Gaussian Process Latent Variable Model
permalink: /:collection/:name/
redirect_from: tutorials/12-gaussian-process-latent-variable-model/
---

# Gaussian Process Latent Variable Model
In a previous tutorial, we discussed latent variable models, in particular probabilistic principal component analysis (pPCA).
Here, we show how to extend the linear mapping used by pPCA to a non-linear mapping between input and output.
For more details about the Gaussian Process Latent Variable Model (GPLVM),
we refer the reader to the [original publication](https://jmlr.org/papers/v6/lawrence05a.html) and a [further extension](http://proceedings.mlr.press/v9/titsias10a/titsias10a.pdf).

In short, the GPLVM is a dimensionality reduction technique that embeds a high-dimensional dataset in a lower-dimensional space.
Importantly, it has the advantage that the linear mapping from the embedding space to the data space can be made non-linear through the use of Gaussian processes.
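
Schematically, in standard GPLVM notation (rather than the exact parameterisation used in the code below): each observation $y_n \in \mathbb{R}^D$ is generated from a low-dimensional latent point $z_n$ via $D$ functions drawn from a shared GP prior,

$$
z_n \sim \mathcal{N}(0, I), \qquad
f_d \sim \mathcal{GP}\bigl(0, k(\cdot, \cdot)\bigr), \qquad
y_{n,d} \mid f_d, z_n \sim \mathcal{N}\bigl(f_d(z_n), \sigma_{\text{noise}}^2\bigr),
$$

so that inferring the latent points $z_n$ amounts to non-linear dimensionality reduction.
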
Let's start by loading some dependencies.
```julia
using Turing
using AbstractGPs, KernelFunctions, Random, Plots

using LinearAlgebra
using VegaLite, DataFrames, StatsPlots, StatsBase
using RDatasets

Random.seed!(1789);
```

We demonstrate the GPLVM with a very small dataset: [Fisher's Iris data set](https://en.wikipedia.org/wiki/Iris_flower_data_set).
This is mostly for reasons of run time, so that the tutorial can be run quickly.
As you will see, one of the major drawbacks of using GPs is their computational cost,
although speeding them up is an active area of research.
We will briefly touch on some ways to speed things up at the end of this tutorial.
We transform the original data with non-linear operations in order to demonstrate the power of GPs to work on non-linear relationships, while keeping the problem reasonably small.

```julia
data = dataset("datasets", "iris")
species = data[!, "Species"]
index = shuffle(1:150)
# we extract the four measured quantities,
# so the dimension of the data is only d=4 for this toy example
dat = Matrix(data[index, 1:4])
labels = data[index, "Species"]

# non-linearize the data to demonstrate the ability of GPs to deal with non-linearity
dat[:, 1] = 0.5 * dat[:, 1] .^ 2 + 0.1 * dat[:, 1] .^ 3
dat[:, 2] = dat[:, 2] .^ 3 + 0.2 * dat[:, 2] .^ 4
dat[:, 3] = 0.1 * exp.(dat[:, 3]) - 0.2 * dat[:, 3] .^ 2
dat[:, 4] = 0.5 * log.(dat[:, 4]) .^ 2 + 0.01 * dat[:, 3] .^ 5

# standardize each column to zero mean and unit variance
dt = fit(ZScoreTransform, dat, dims=1);
StatsBase.transform!(dt, dat);
```

We will start by demonstrating the basic similarity between pPCA (see the tutorial on this topic) and the GPLVM model.
Indeed, pPCA is basically equivalent to running the GPLVM model with an automatic relevance determination (ARD) linear kernel.

First, we re-introduce the pPCA model (see the tutorial on pPCA for details):
```julia
@model function pPCA(x, ::Type{TV}=Array{Float64}) where {TV}
    # Dimensionality of the problem.
    N, D = size(x)
    # latent variable z
    z ~ filldist(Normal(), D, N)
    # weights/loadings W
    w ~ filldist(Normal(), D, D)
    mu = (w * z)'
    for d in 1:D
        x[:, d] ~ MvNormal(mu[:, d], I)
    end
end;
```
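
To connect this to the GPLVM (a standard observation, spelled out here for clarity): integrating the loadings $w$ out of the model above gives, for each data column,

$$
x_{:, d} \mid z \;\sim\; \mathcal{N}\!\left(0,\; z^\top z + I\right),
$$

which is exactly a zero-mean GP with a linear kernel evaluated at the latent points. The `GPLVM_linear` model below therefore reproduces pPCA, up to the ARD scaling of the latent dimensions and the treatment of the mean and noise terms.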

We define two different kernels: a simple linear kernel with an automatic relevance determination (ARD) transform, and a squared exponential kernel.
```julia
linear_kernel(α) = LinearKernel() ∘ ARDTransform(α)
sekernel(α, σ) = σ * SqExponentialKernel() ∘ ARDTransform(α);
```
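
As a quick, optional sanity check (not part of the original tutorial), we can evaluate both kernels on a few random latent points; `kernelmatrix` and `ColVecs` come from KernelFunctions.jl, and the unit ARD weights and scale used below are arbitrary placeholder values:

```julia
# hypothetical toy check of the two kernel constructors defined above
Ztoy = randn(2, 5)                                   # 2 latent dimensions, 5 points
Klin = kernelmatrix(linear_kernel(ones(2)), ColVecs(Ztoy))
Kse = kernelmatrix(sekernel(ones(2), 1.0), ColVecs(Ztoy))
@assert size(Klin) == (5, 5) && size(Kse) == (5, 5)  # one entry per pair of points
```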
And here is the GPLVM model.
We create separate models for the two types of kernel.
```julia
@model function GPLVM_linear(Y, K=4, ::Type{T}=Float64) where {T}
    # Dimensionality of the problem.
    N, D = size(Y)
    # K is the dimension of the latent space
    @assert K <= D
    noise = 1e-3

    # Priors
    α ~ MvLogNormal(MvNormal(zeros(K), I))
    Z ~ filldist(Normal(), K, N)
    mu ~ filldist(Normal(), N)

    kernel = linear_kernel(α)

    gp = GP(mu, kernel)
    cv = cov(gp(ColVecs(Z), noise))
    Y ~ filldist(MvNormal(mu, cv), D)
end;

@model function GPLVM(Y, K=4, ::Type{T}=Float64) where {T}
    # Dimensionality of the problem.
    N, D = size(Y)
    # K is the dimension of the latent space
    @assert K <= D
    noise = 1e-3

    # Priors
    α ~ MvLogNormal(MvNormal(zeros(K), I))
    σ ~ LogNormal(0.0, 1.0)
    Z ~ filldist(Normal(), K, N)
    mu ~ filldist(Normal(), N)

    kernel = sekernel(α, σ)

    gp = GP(mu, kernel)
    cv = cov(gp(ColVecs(Z), noise))
    Y ~ filldist(MvNormal(mu, cv), D)
end;
```
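
Written out, the last two lines of each model say that the $D$ observed columns are conditionally independent draws from the same multivariate normal whose covariance is the kernel matrix over the latent points (this is just a restatement of the code above, added for clarity):

$$
p(Y \mid Z, \mu, \theta) \;=\; \prod_{d=1}^{D} \mathcal{N}\!\left(Y_{:, d} \;\middle|\; \mu,\; K_{ZZ} + 10^{-3} I\right),
$$

where $\theta$ denotes the kernel parameters ($\alpha$, plus $\sigma$ for the squared exponential kernel) and $K_{ZZ}$ is the kernel matrix evaluated at the columns of $Z$.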

```julia
# Standard GPs don't scale very well in n, so we use a small subsample for the purpose of this tutorial
n_data = 40
# number of features to use from dataset
n_features = 4
# latent dimension for GP case
ndim = 4;
```

```julia
ppca = pPCA(dat[1:n_data, 1:n_features])
chain_ppca = sample(ppca, NUTS(), 1000);
```

```julia
# we extract the posterior mean estimates of the parameters from the chain
w = reshape(mean(group(chain_ppca, :w))[:, 2], (n_features, n_features))
z = reshape(mean(group(chain_ppca, :z))[:, 2], (n_features, n_data))
X = w * z

df_pre = DataFrame(z', :auto)
rename!(df_pre, Symbol.(["z" * string(i) for i in collect(1:n_features)]))
df_pre[!, :type] = labels[1:n_data]
p_ppca = df_pre |> @vlplot(:point, x=:z1, y=:z2, color="type:n")
```
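
Since both VegaLite and StatsPlots are loaded, the same scatter plot can also be drawn with StatsPlots alone; this optional one-liner (not in the original tutorial) is included for readers who prefer to stick to a single plotting framework:

```julia
# hypothetical StatsPlots equivalent of the VegaLite plot above
@df df_pre scatter(:z1, :z2, group=:type, xlabel="z1", ylabel="z2")
```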

We can see that pPCA fails to distinguish the groups.
In particular, the `setosa` species is not clearly separated from `versicolor` and `virginica`.
This is due to the non-linearities that we introduced, as without them the two groups can be clearly distinguished
using pPCA (see the pPCA tutorial).

Let's try the same with our linear kernel GPLVM model.
```julia
gplvm_linear = GPLVM_linear(dat[1:n_data, 1:n_features], ndim)

chain_linear = sample(gplvm_linear, NUTS(), 500)
# we extract the posterior mean estimates of the parameters from the chain
z_mean = reshape(mean(group(chain_linear, :Z))[:, 2], (n_features, n_data))
alpha_mean = mean(group(chain_linear, :α))[:, 2]
```

```julia
df_gplvm_linear = DataFrame(z_mean', :auto)
rename!(df_gplvm_linear, Symbol.(["z" * string(i) for i in collect(1:ndim)]))
df_gplvm_linear[!, :sample] = 1:n_data
df_gplvm_linear[!, :labels] = labels[1:n_data]
# pick the two latent dimensions with the largest ARD weights
# and use them as the plotting axes
alpha_indices = sortperm(alpha_mean, rev=true)[1:2]
println(alpha_indices)
df_gplvm_linear[!, :ard1] = z_mean[alpha_indices[1], :]
df_gplvm_linear[!, :ard2] = z_mean[alpha_indices[2], :]

p_linear = df_gplvm_linear |> @vlplot(:point, x=:ard1, y=:ard2, color="labels:n")
p_linear
```
We can see that, similar to the pPCA case, the linear kernel GPLVM fails to distinguish between the two groups
(`setosa` on the one hand, and `virginica` and `versicolor` on the other).

Finally, we demonstrate that by changing the kernel to a non-linear function, we are able to separate the data again.
```julia
gplvm = GPLVM(dat[1:n_data, 1:n_features], ndim)

chain_gplvm = sample(gplvm, NUTS(), 500)
# we extract the posterior mean estimates of the parameters from the chain
z_mean = reshape(mean(group(chain_gplvm, :Z))[:, 2], (ndim, n_data))
alpha_mean = mean(group(chain_gplvm, :α))[:, 2]
```

```julia
df_gplvm = DataFrame(z_mean', :auto)
rename!(df_gplvm, Symbol.(["z" * string(i) for i in collect(1:ndim)]))
df_gplvm[!, :sample] = 1:n_data
df_gplvm[!, :labels] = labels[1:n_data]
alpha_indices = sortperm(alpha_mean, rev=true)[1:2]
println(alpha_indices)
df_gplvm[!, :ard1] = z_mean[alpha_indices[1], :]
df_gplvm[!, :ard2] = z_mean[alpha_indices[2], :]

p_gplvm = df_gplvm |> @vlplot(:point, x=:ard1, y=:ard2, color="labels:n")
p_gplvm
```

```julia; echo=false
let
    @assert abs(mean(z_mean[alpha_indices[1], labels[1:n_data] .== "setosa"]) - mean(z_mean[alpha_indices[1], labels[1:n_data] .!= "setosa"])) > 1.4
end
```

Now, the split between the two groups is visible again.

### Speeding up inference

Gaussian processes tend to be slow, as naive inference requires on the order of $O(n^3)$ operations.
Here, we demonstrate a simple speedup using the Stheno library.
Speeding up Gaussian process inference is an active area of research.
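
As background (a standard result about inducing-point approximations, not specific to Stheno): exact GP inference requires factorising an $n \times n$ covariance matrix, while an approximation based on $m \ll n$ inducing points reduces the dominant cost to roughly

$$
\mathcal{O}(n^3) \;\longrightarrow\; \mathcal{O}(n m^2).
$$
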
```julia
using Stheno
@model function GPLVM_sparse(Y, K, ::Type{T}=Float64) where {T}
    # Dimensionality of the problem.
    N, D = size(Y)
    # dimension of latent space
    @assert K <= D
    # number of inducing points
    n_inducing = 25
    noise = 1e-3

    # Priors
    α ~ MvLogNormal(MvNormal(zeros(K), I))
    σ ~ LogNormal(1.0, 1.0)
    Z ~ filldist(Normal(), K, N)
    mu ~ filldist(Normal(), N)

    # same squared exponential ARD kernel as defined above
    kernel = sekernel(α, σ)

    ## Standard (non-sparse) GP, kept for comparison
    # gpc = GPC()
    # f = Stheno.wrap(GP(kernel), gpc)
    # gp = f(ColVecs(Z), noise)
    # Y ~ filldist(gp, D)

    ## Sparse GP
    # alternative (commented-out) choices for the inducing-point locations:
    # xu = reshape(repeat(locations, K), :, K) # inducing points
    # xu = reshape(repeat(collect(range(-2.0, 2.0; length=20)), K), :, K) # inducing points
    lbound = minimum(Y) + 1e-6
    ubound = maximum(Y) - 1e-6
    # locations ~ filldist(Uniform(lbound, ubound), n_inducing)
    # locations = [-2., -1.5, -1., -0.5, -0.25, 0.25, 0.5, 1., 2.]
    # locations = collect(LinRange(lbound, ubound, n_inducing))
    # place the inducing points at quantiles of the observed values
    locations = quantile(vec(Y), LinRange(0.01, 0.99, n_inducing))
    xu = reshape(locations, 1, :)
    gp = Stheno.wrap(GP(kernel), GPC())
    fobs = gp(ColVecs(Z), noise)
    finducing = gp(xu, 1e-12)
    sfgp = SparseFiniteGP(fobs, finducing)
    cv = cov(sfgp.fobs)
    Y ~ filldist(MvNormal(mu, cv), D)
end
```

```julia
n_data = 50
gplvm_sparse = GPLVM_sparse(dat[1:n_data, :], ndim)

chain_gplvm_sparse = sample(gplvm_sparse, NUTS(), 500)
# we extract the posterior mean estimates of the parameters from the chain
z_mean = reshape(mean(group(chain_gplvm_sparse, :Z))[:, 2], (ndim, n_data))
alpha_mean = mean(group(chain_gplvm_sparse, :α))[:, 2]
```

```julia
df_gplvm_sparse = DataFrame(z_mean', :auto)
rename!(df_gplvm_sparse, Symbol.(["z" * string(i) for i in collect(1:ndim)]))
df_gplvm_sparse[!, :sample] = 1:n_data
df_gplvm_sparse[!, :labels] = labels[1:n_data]
alpha_indices = sortperm(alpha_mean, rev=true)[1:2]
df_gplvm_sparse[!, :ard1] = z_mean[alpha_indices[1], :]
df_gplvm_sparse[!, :ard2] = z_mean[alpha_indices[2], :]
p_sparse = df_gplvm_sparse |> @vlplot(:point, x=:ard1, y=:ard2, color="labels:n")
p_sparse
```

Comparing the runtimes of the two versions, we observe a clear speed-up with the sparse version.

```julia; echo=false
let
    @assert abs(mean(z_mean[alpha_indices[1], labels[1:n_data] .== "setosa"]) - mean(z_mean[alpha_indices[1], labels[1:n_data] .!= "setosa"])) > 1.0
    @assert abs(mean(z_mean[alpha_indices[2], labels[1:n_data] .== "setosa"]) - mean(z_mean[alpha_indices[2], labels[1:n_data] .!= "setosa"])) > 0.1
end
```

```julia, echo=false, skip="notebook"
if isdefined(Main, :TuringTutorials)
    Main.TuringTutorials.tutorial_footer(WEAVE_ARGS[:folder], WEAVE_ARGS[:file])
end
```
Review comment: Is there a specific reason for using both StatsPlots and VegaLite? Seems a bit complicated to use two different plotting frameworks in one tutorial.