Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General improvements and fixes #133

Merged
merged 52 commits into from
Dec 4, 2022
Merged
Changes from 1 commit
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
4d48a35
additional deps and test deps
torfjelde Sep 11, 2021
2df4a2b
updated stepping code to use AbstractSwapStrategy and made TemperedSa…
torfjelde Oct 4, 2021
83c5c49
made TemperedSampler concretely typed and fixed soem docs
torfjelde Oct 4, 2021
9e9153c
introduced AbstractSwapStrategy and removed get_params and make_tempe…
torfjelde Oct 4, 2021
26f12f1
updated stepping.jl
torfjelde Oct 4, 2021
aa88ae4
added docstring for make_tempered_model
torfjelde Oct 4, 2021
2523a9d
updated adaptation.jl and made structs concrete
torfjelde Oct 4, 2021
7462ebf
updated ladders.jl
torfjelde Oct 4, 2021
3fb13e9
addressed some comments
torfjelde Oct 13, 2021
90e60f2
added tests
torfjelde Oct 13, 2021
e30718d
fixed a bug
torfjelde Oct 13, 2021
9e8607a
updated the StateHistoryCallback a bit
torfjelde Oct 20, 2021
3bb79df
made the distinction between chains and processes clearer
torfjelde Oct 20, 2021
90722c9
added tests
torfjelde Oct 20, 2021
50ad1ec
fixed incorrect statement
torfjelde Oct 20, 2021
0c204ac
renamed some fields to be more descriptive and fixed left-over bug
torfjelde Oct 20, 2021
e2bbc90
updated tests
torfjelde Oct 20, 2021
e7466cc
removed some show from tests
torfjelde Oct 20, 2021
0a59517
began updating docstrings
torfjelde Oct 20, 2021
f7a7f31
fixed docstring for TemeperedState
torfjelde Oct 21, 2021
696d8d1
fix exports
torfjelde Oct 21, 2021
bbb2fc2
a bunch of renaming
torfjelde Oct 21, 2021
8ac7374
deleted plotting functionality
torfjelde Oct 21, 2021
f7c46e7
fixed bug and added should_swap method
torfjelde Oct 21, 2021
dcb2a0d
improved tests
torfjelde Oct 21, 2021
287b501
Typo
ParadaCarleton Nov 20, 2021
6239710
Typo
ParadaCarleton Nov 20, 2021
8c1b8ff
implemented adaptation scheme for inverse temperatures using a geomet…
torfjelde Dec 7, 2021
f70690f
made some changes to some code that I cannot understand the original …
torfjelde Dec 7, 2021
afa2900
added parameter for controlling which type of schedule to use when ad…
torfjelde Dec 7, 2021
c0d9e61
make number of steps taken for each adaptor part of their state
torfjelde Dec 16, 2021
4563d33
improvements to parameterization of the adaptation techniques
torfjelde Nov 14, 2022
86fbf0d
updated test env
torfjelde Nov 14, 2022
229059b
keep track of swapping ratios
torfjelde Nov 14, 2022
c4583f9
tests are now runnable
torfjelde Nov 14, 2022
a1efd11
commented out unused code
torfjelde Nov 14, 2022
c221572
Corrected typo
HarrisonWilde Nov 16, 2022
632cca9
Added 1D GMM, sort of works for it
HarrisonWilde Nov 16, 2022
f30acfa
Make `StandardSwap` the default strategy when one
HarrisonWilde Nov 16, 2022
0a0b131
Fixing test case for GMM
HarrisonWilde Nov 16, 2022
910dc6d
Implementing burn-in, introduces depedency on StatsBase
HarrisonWilde Nov 16, 2022
a4437f9
Fixed error with burnin
HarrisonWilde Nov 16, 2022
b9c6a90
cleaning up working_code
HarrisonWilde Nov 16, 2022
79cbbf0
QoL improvements on the code
HarrisonWilde Nov 16, 2022
4336516
Removing `StatsBase` dependency and `discard_initial` override
HarrisonWilde Nov 16, 2022
915b610
Adding back accidentally deleted RandomPermutationSwap stuff
HarrisonWilde Nov 16, 2022
4898193
cleaned up testing a bit
torfjelde Nov 17, 2022
13cb3b2
made the compute_tempered_logdensities a bit more general
torfjelde Nov 17, 2022
cb6937e
Implementing `tempered_sample` to allow for no-swap burn-in and easie…
HarrisonWilde Nov 17, 2022
db4b842
Tweaking sample call
HarrisonWilde Nov 17, 2022
fce26aa
Working burnin
HarrisonWilde Nov 17, 2022
8a6b47e
Merge pull request #137 from TuringLang/harry/improvements_additions
yebai Dec 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 77 additions & 44 deletions src/swapping.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,50 @@
"""
AbstractSwapStrategy

Represents a strategy for swapping between parallel chains.

A concrete subtype is expected to implement the method [`swap_step`](@ref).
"""
abstract type AbstractSwapStrategy end

"""
StandardSwap <: AbstractSwapStrategy

At every swap step taken, this strategy samples a single chain index `i` and proposes
a swap between chains `i` and `i + 1`.

This approach goes under a number of names, e.g. Parallel Tempering (PT) MCMC and Replica-Exchange MCMC.[^PTPH05]

The sampling of the chain index ensures reversibility/detailed balance is satisfied.

# References
[^PTPH05]: Earl, D. J., & Deem, M. W., Parallel tempering: theory, applications, and new perspectives, Physical Chemistry Chemical Physics, 7(23), 3910–3916 (2005).
"""
struct StandardSwap <: AbstractSwapStrategy end

"""
RandomPermutationSwap <: AbstractSwapStrategy

At every swap step taken, this strategy randomly shuffles all the chain indices
and then iterates through them, proposing swaps for neighboring chains.

The shuffling of chain indices ensures reversibility/detailed balance is satisfied.
"""
struct RandomPermutationSwap <: AbstractSwapStrategy end


"""
NonReversibleSwap <: AbstractSwapStrategy

At every swap step taken, this strategy _deterministically_ traverses first the
odd chain indices, proposing swaps between neighbors, and then in the _next_ swap step
taken traverses even chain indices, proposing swaps between neighbors.

Note that this method is _not_ reversible, and does not satisfy detailed balance.
As a result, this method is asymptotically biased.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, not sure if this is the right place to put this, but this statement is not true. It is not reversible, but it and preserves the correct distribution. This non-reversibility is actually the main reason why doing something like this swapping strategy is better. The non-reversibility prevents the swapping from devolving into a diffusion-type process. The Syed 2019 paper goes into this in a lot of detail.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ptiede Do you have a link to that paper?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sorry! I should have included it. Here is the arXiv link https://arxiv.org/abs/1905.02939

A follow-up paper https://arxiv.org/abs/2102.07720 also goes into detail for how to further optimize PT. The primary author of that paper, @s-syed, and I were actually hoping to implement something like this into this package so if this repo is picking up again, we would love to help!

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! I am more than happy to answer any questions. This odd-even scheme is actually extremely important and in some sense achieves the optimal performance and should always be used over the reversible counterpart. Furthermore, the tuning guidelines are very different and results in significantly different tuning guidelines.

In practice we find a 10-100x boost in performance compared to reversible swapping schemes and their corresponding tuning guidelines that have been implemented so far.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, not sure if this is the right place to put this, but this statement is not true. It is not reversible, but it and preserves the correct distribution.

Ah, thank you for pointing this out! This is indeed an incorrect statement.

And thank you for pointing me to those papers! Just read the first one, and it's real neat stuff. And we'd be happy to collaborate on this package. Personally I'm just starting to get into PT samplers, so it would be awesome to include someone with both practical and theoretical experience with these things.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! I'd love to help. I have some experience implementing PT samplers for use with HPC and clusters, so I would be thrilled to help in any way I can! My old code was written in C++, so moving to Julia would be great.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!:) I'll push some changes I've accumulated locally very soon.

"""
struct NonReversibleSwap <: AbstractSwapStrategy end

"""
swap_betas(chain_index, k)

Expand All @@ -9,49 +56,24 @@ function swap_betas(chain_index, k)
return sortperm(chain_index), chain_index
end

function make_tempered_loglikelihood end
function get_params end


"""
get_tempered_loglikelihoods_and_params(model, sampler, states, k, Δ, chain_index)
compute_tempered_logdensities(model, sampler, transition, transition_other, β)

Temper the `model`'s density using the `k`th and `k + 1`th temperatures
selected via `Δ` and `chain_index`. Then retrieve the parameters using the chains'
current transitions extracted from the collection of `states`.
Return `(logπ(transition, β), logπ(transition_other, β))` where `logπ(x, β)` denotes the
log-density for `model` with inverse-temperature `β`.
"""
function get_tempered_loglikelihoods_and_params(
model,
sampler::AbstractMCMC.AbstractSampler,
states,
k::Integer,
Δ::Vector{Real},
chain_index::Vector{<:Integer}
)

logπk = make_tempered_loglikelihood(model, Δ[k])
logπkp1 = make_tempered_loglikelihood(model, Δ[k + 1])

θk = get_params(states[chain_index[k]][1])
θkp1 = get_params(states[chain_index[k + 1]][1])

return logπk, logπkp1, θk, θkp1
end

function compute_tempered_logdensities end

"""
swap_acceptance_pt(logπk, logπkp1, θk, θkp1)
swap_acceptance_pt(logπk, logπkp1)

Calculates and returns the swap acceptance ratio for swapping the temperature
of two chains. Using tempered likelihoods `logπk` and `logπkp1` at the chains'
current state parameters `θk` and `θkp1`.
current state parameters.
"""
function swap_acceptance_pt(logπk, logπkp1, θk, θkp1)
return min(
1,
exp(logπkp1(θk) + logπk(θkp1)) / exp(logπk(θk) + logπkp1(θkp1))
# exp(abs(βk - βkp1) * abs(AdvancedMH.logdensity(model, samplek) - AdvancedMH.logdensity(model, samplekp1)))
)
function swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1)
return (logπkp1_θk + logπk_θkp1) - (logπk_θk + logπkp1_θkp1)
end


Expand All @@ -61,21 +83,32 @@ end
Attempt to swap the temperatures of two chains by tempering the densities and
calculating the swap acceptance ratio; then swapping if it is accepted.
"""
function swap_attempt(model, sampler, ts, k, adapt, n)

logπk, logπkp1, θk, θkp1 = get_tempered_loglikelihoods_and_params(model, sampler, ts.states, k, ts.Δ, ts.chain_index)
function swap_attempt(rng, model, sampler, ts, k, adapt, n)
# Extract the relevant transitions.
transitionk = first(ts.states[ts.chain_index[k]])
transitionkp1 = first(ts.states[ts.chain_index[k + 1]])
# Evaluate logdensity for both parameters for each tempered density.
logπk_θk, logπk_θkp1 = compute_tempered_logdensities(
model, sampler, transitionk, transitionkp1, ts.Δ[k]
)
logπkp1_θkp1, logπkp1_θk = compute_tempered_logdensities(
model, sampler, transitionkp1, transitionk, ts.Δ[k + 1]
)

swap_ar = swap_acceptance_pt(logπk, logπkp1, θk, θkp1)
U = rand(Distributions.Uniform(0, 1))

# If the proposed temperature swap is accepted according to swap_ar and U, swap the temperatures for future steps
if U ≤ swap_ar
ts.Δ_index, ts.chain_index = swap_betas(ts.chain_index, k)
# If the proposed temperature swap is accepted according `logα`,
# swap the temperatures for future steps.
logα = swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1)
if -Random.randexp(rng) ≤ logα
Δ_index, chain_index = swap_betas(ts.chain_index, k)
@set! ts.Δ_index = Δ_index
@set! ts.chain_index = chain_index
end

# Adaptation steps affects Ρ and Δ, as the Ρ is adapted before a new Δ is generated and returned
if adapt
ts.Ρ, ts.Δ = adapt_ladder(ts.Ρ, ts.Δ, k, swap_ar, n)
P, Δ = adapt_ladder(ts.Ρ, ts.Δ, k, min(one(logα), exp(logα)), n)
@set! ts.Ρ = P
@set! ts.Δ = Δ
end
return ts
end
end