-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
General improvements and fixes #133
Changes from 4 commits
4d48a35
2df4a2b
83c5c49
9e9153c
26f12f1
aa88ae4
2523a9d
7462ebf
3fb13e9
90e60f2
e30718d
9e8607a
3bb79df
90722c9
50ad1ec
0c204ac
e2bbc90
e7466cc
0a59517
f7a7f31
696d8d1
bbb2fc2
8ac7374
f7c46e7
dcb2a0d
287b501
6239710
8c1b8ff
f70690f
afa2900
c0d9e61
4563d33
86fbf0d
229059b
c4583f9
a1efd11
c221572
632cca9
f30acfa
0a0b131
910dc6d
a4437f9
b9c6a90
79cbbf0
4336516
915b610
4898193
13cb3b2
cb6937e
db4b842
fce26aa
8a6b47e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,19 +41,16 @@ Chains: chain_index: Δ_index: | |
| | | | 2 3 1 4 3 1 2 4 | ||
| | | | | ||
""" | ||
mutable struct TemperedState | ||
states :: Array{Any} | ||
Δ :: Vector{<:Real} | ||
Δ_index :: Vector{<:Integer} | ||
chain_index :: Vector{<:Integer} | ||
step_counter :: Integer | ||
total_steps :: Integer | ||
Δ_history :: Array{<:Real, 2} | ||
Δ_index_history :: Array{<:Integer, 2} | ||
Ρ :: Vector{AdaptiveState} | ||
@concrete struct TemperedState | ||
states | ||
Δ | ||
Δ_index | ||
chain_index | ||
step_counter | ||
total_steps | ||
Ρ | ||
end | ||
|
||
|
||
""" | ||
For each `β` in `Δ`, carry out a step with a tempered model at the corresponding `β` inverse temperature, | ||
resulting in a list of transitions and states, the transition associated with `β₀ = 1` is then returned with the | ||
|
@@ -63,21 +60,23 @@ function AbstractMCMC.step( | |
rng::Random.AbstractRNG, | ||
model, | ||
spl::TemperedSampler; | ||
init_params=nothing, | ||
kwargs... | ||
) | ||
states = [ | ||
AbstractMCMC.step( | ||
rng, | ||
make_tempered_model(model, spl.Δ[spl.Δ_init[i]]), | ||
spl.internal_sampler; | ||
init_params=init_params !== nothing ? init_params[i] : nothing, | ||
kwargs... | ||
) | ||
for i in 1:length(spl.Δ) | ||
] | ||
return ( | ||
states[sortperm(spl.Δ_init)[1]][1], | ||
first(states[argmax(spl.Δ_init)]), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should this be argmax or argmin? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe it should be argmin. @torfjelde ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @torfjelde once I get a confirmation on whether it should be argmin or argmax, is it cool if I merge this and register an update to 0.2.0 (assuming tests pass)? Should simplify TuringLang/Turing.jl#1628 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
But IMO we should just make it
Yep. There are several things that IMO should be done differently in that PR, e.g. we should be using the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @torfjelde if you think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
no, it's argmin There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done-diddit 👍 |
||
TemperedState( | ||
states,spl.Δ, spl.Δ_init, sortperm(spl.Δ_init), 1, 1, Array{Real, 2}(spl.Δ'), Array{Integer, 2}(spl.Δ_init'), spl.Ρ | ||
states, spl.Δ, spl.Δ_init, sortperm(spl.Δ_init), 1, 1, spl.Ρ | ||
) | ||
) | ||
end | ||
|
@@ -90,30 +89,29 @@ function AbstractMCMC.step( | |
) | ||
if ts.step_counter == spl.N_swap | ||
ts = swap_step(rng, model, spl, ts) | ||
ts.step_counter = 0 | ||
@set! ts.step_counter = 0 | ||
else | ||
ts.states = [ | ||
@set! ts.states = [ | ||
AbstractMCMC.step( | ||
rng, | ||
make_tempered_model(model, ts.Δ[ts.Δ_index[i]]), | ||
spl.internal_sampler, | ||
ts.states[i][2]; | ||
ts.states[ts.chain_index[i]][2]; | ||
kwargs... | ||
) | ||
for i in 1:length(ts.Δ) | ||
] | ||
ts.step_counter += 1 | ||
@set! ts.step_counter += 1 | ||
end | ||
|
||
ts.Δ_history = vcat(ts.Δ_history, Array{Real, 2}(ts.Δ')) | ||
ts.Δ_index_history = vcat(ts.Δ_index_history, Array{Integer, 2}(ts.Δ_index')) | ||
ts.total_steps += 1 | ||
return ts.states[ts.chain_index[1]][1], ts # Use chain_index[1] to ensure the sample from the target is always returned for the step | ||
@set! ts.total_steps += 1 | ||
# Use `chain_index[1]` to ensure the sample from the target is always returned for the step. | ||
return ts.states[ts.chain_index[1]][1], ts | ||
end | ||
|
||
|
||
""" | ||
swap_step(rng, model, spl, ts) | ||
swap_step([strategy::SwapStrategy, ]rng, model, spl, ts) | ||
|
||
Uses the internals of the passed `TemperedSampler` - `spl` - and `TemperedState` - | ||
`ts` - to perform a "swap step" between temperatures, in accordance with the relevant | ||
|
@@ -122,35 +120,60 @@ swap strategy. | |
function swap_step( | ||
rng::Random.AbstractRNG, | ||
model, | ||
spl::TemperedSampler, | ||
sampler::TemperedSampler, | ||
ts::TemperedState | ||
) | ||
return swap_step(swapstrategy(sampler), rng, model, sampler, ts) | ||
end | ||
|
||
function swap_step( | ||
strategy::StandardSwap, | ||
rng::Random.AbstractRNG, | ||
model, | ||
sampler::TemperedSampler, | ||
ts::TemperedState | ||
) | ||
L = length(ts.Δ) - 1 | ||
sampler = spl.internal_sampler | ||
k = rand(rng, 1:L) | ||
return swap_attempt(rng, model, sampler.internal_sampler, ts, k, sampler.adapt, ts.total_steps / L) | ||
end | ||
|
||
if spl.swap_strategy == :standard | ||
function swap_step( | ||
strategy::RandomPermutationSwap, | ||
rng::Random.AbstractRNG, | ||
model, | ||
sampler::TemperedSampler, | ||
ts::TemperedState | ||
) | ||
L = length(ts.Δ) - 1 | ||
levels = Vector{Int}(undef, L) | ||
Random.randperm!(rng, levels) | ||
|
||
k = rand(rng, Distributions.Categorical(L)) # Pick randomly from 1, 2, ..., k - 1 | ||
ts = swap_attempt(model, sampler, ts, k, spl.adapt, ts.total_steps / L) | ||
# Iterate through all levels and attempt swaps. | ||
for k in levels | ||
ts = swap_attempt(rng, model, sampler.internal_sampler, ts, k, sampler.adapt, ts.total_steps) | ||
end | ||
return ts | ||
end | ||
|
||
function swap_step( | ||
strategy::NonReversibleSwap, | ||
rng::Random.AbstractRNG, | ||
model, | ||
sampler::TemperedSampler, | ||
ts::TemperedState | ||
) | ||
L = length(ts.Δ) - 1 | ||
# Alternate between swapping odds and evens. | ||
levels = if ts.total_steps % (2 * sampler.N_swap) == 0 | ||
1:2:L | ||
else | ||
2:2:L | ||
end | ||
|
||
# Define a vector to populate with levels at which to propose swaps according to swap_strategy | ||
levels = Vector{Int}(undef, L) | ||
if spl.swap_strategy == :nonrev | ||
if ts.step_counter % (2 * spl.N_swap) == 0 | ||
levels = 1:2:L | ||
else | ||
levels = 2:2:L | ||
end | ||
elseif spl.swap_strategy == :randperm | ||
randperm!(rng, levels) | ||
end | ||
|
||
for k in levels | ||
ts = swap_attempt(model, sampler, ts, k, spl.adapt, ts.total_steps) | ||
end | ||
|
||
# Iterate through all levels and attempt swaps. | ||
for k in levels | ||
ts = swap_attempt(rng, model, sampler.internal_sampler, ts, k, sampler.adapt, ts.total_steps) | ||
end | ||
return ts | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,50 @@ | ||
""" | ||
AbstractSwapStrategy | ||
|
||
Represents a strategy for swapping between parallel chains. | ||
|
||
A concrete subtype is expected to implement the method [`swap_step`](@ref). | ||
""" | ||
abstract type AbstractSwapStrategy end | ||
|
||
""" | ||
StandardSwap <: AbstractSwapStrategy | ||
|
||
At every swap step taken, this strategy samples a single chain index `i` and proposes | ||
a swap between chains `i` and `i + 1`. | ||
|
||
This approach goes under a number of names, e.g. Parallel Tempering (PT) MCMC and Replica-Exchange MCMC.[^PTPH05] | ||
|
||
The sampling of the chain index ensures reversibility/detailed balance is satisfied. | ||
|
||
# References | ||
[^PTPH05]: Earl, D. J., & Deem, M. W., Parallel tempering: theory, applications, and new perspectives, Physical Chemistry Chemical Physics, 7(23), 3910–3916 (2005). | ||
""" | ||
struct StandardSwap <: AbstractSwapStrategy end | ||
|
||
""" | ||
RandomPermutationSwap <: AbstractSwapStrategy | ||
|
||
At every swap step taken, this strategy randomly shuffles all the chain indices | ||
and then iterates through them, proposing swaps for neighboring chains. | ||
|
||
The shuffling of chain indices ensures reversibility/detailed balance is satisfied. | ||
""" | ||
struct RandomPermutationSwap <: AbstractSwapStrategy end | ||
|
||
|
||
""" | ||
NonReversibleSwap <: AbstractSwapStrategy | ||
|
||
At every swap step taken, this strategy _deterministically_ traverses first the | ||
odd chain indices, proposing swaps between neighbors, and then in the _next_ swap step | ||
taken traverses even chain indices, proposing swaps between neighbors. | ||
|
||
Note that this method is _not_ reversible, and does not satisfy detailed balance. | ||
As a result, this method is asymptotically biased. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, not sure if this is the right place to put this, but this statement is not true. It is not reversible, but it and preserves the correct distribution. This non-reversibility is actually the main reason why doing something like this swapping strategy is better. The non-reversibility prevents the swapping from devolving into a diffusion-type process. The Syed 2019 paper goes into this in a lot of detail. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ptiede Do you have a link to that paper? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes sorry! I should have included it. Here is the arXiv link https://arxiv.org/abs/1905.02939 A follow-up paper https://arxiv.org/abs/2102.07720 also goes into detail for how to further optimize PT. The primary author of that paper, @s-syed, and I were actually hoping to implement something like this into this package so if this repo is picking up again, we would love to help! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi! I am more than happy to answer any questions. This odd-even scheme is actually extremely important and in some sense achieves the optimal performance and should always be used over the reversible counterpart. Furthermore, the tuning guidelines are very different and results in significantly different tuning guidelines. In practice we find a 10-100x boost in performance compared to reversible swapping schemes and their corresponding tuning guidelines that have been implemented so far. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Ah, thank you for pointing this out! This is indeed an incorrect statement. And thank you for pointing me to those papers! Just read the first one, and it's real neat stuff. And we'd be happy to collaborate on this package. Personally I'm just starting to get into PT samplers, so it would be awesome to include someone with both practical and theoretical experience with these things. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Awesome! I'd love to help. I have some experience implementing PT samplers for use with HPC and clusters, so I would be thrilled to help in any way I can! My old code was written in C++, so moving to Julia would be great. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Awesome!:) I'll push some changes I've accumulated locally very soon. |
||
""" | ||
struct NonReversibleSwap <: AbstractSwapStrategy end | ||
|
||
""" | ||
swap_betas(chain_index, k) | ||
|
||
|
@@ -9,49 +56,24 @@ function swap_betas(chain_index, k) | |
return sortperm(chain_index), chain_index | ||
end | ||
|
||
function make_tempered_loglikelihood end | ||
function get_params end | ||
|
||
|
||
""" | ||
get_tempered_loglikelihoods_and_params(model, sampler, states, k, Δ, chain_index) | ||
compute_tempered_logdensities(model, sampler, transition, transition_other, β) | ||
|
||
Temper the `model`'s density using the `k`th and `k + 1`th temperatures | ||
selected via `Δ` and `chain_index`. Then retrieve the parameters using the chains' | ||
current transitions extracted from the collection of `states`. | ||
Return `(logπ(transition, β), logπ(transition_other, β))` where `logπ(x, β)` denotes the | ||
log-density for `model` with inverse-temperature `β`. | ||
""" | ||
function get_tempered_loglikelihoods_and_params( | ||
model, | ||
sampler::AbstractMCMC.AbstractSampler, | ||
states, | ||
k::Integer, | ||
Δ::Vector{Real}, | ||
chain_index::Vector{<:Integer} | ||
) | ||
|
||
logπk = make_tempered_loglikelihood(model, Δ[k]) | ||
logπkp1 = make_tempered_loglikelihood(model, Δ[k + 1]) | ||
|
||
θk = get_params(states[chain_index[k]][1]) | ||
θkp1 = get_params(states[chain_index[k + 1]][1]) | ||
|
||
return logπk, logπkp1, θk, θkp1 | ||
end | ||
|
||
function compute_tempered_logdensities end | ||
|
||
""" | ||
swap_acceptance_pt(logπk, logπkp1, θk, θkp1) | ||
swap_acceptance_pt(logπk, logπkp1) | ||
|
||
Calculates and returns the swap acceptance ratio for swapping the temperature | ||
of two chains. Using tempered likelihoods `logπk` and `logπkp1` at the chains' | ||
current state parameters `θk` and `θkp1`. | ||
current state parameters. | ||
""" | ||
function swap_acceptance_pt(logπk, logπkp1, θk, θkp1) | ||
return min( | ||
1, | ||
exp(logπkp1(θk) + logπk(θkp1)) / exp(logπk(θk) + logπkp1(θkp1)) | ||
# exp(abs(βk - βkp1) * abs(AdvancedMH.logdensity(model, samplek) - AdvancedMH.logdensity(model, samplekp1))) | ||
) | ||
function swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) | ||
return (logπkp1_θk + logπk_θkp1) - (logπk_θk + logπkp1_θkp1) | ||
end | ||
|
||
|
||
|
@@ -61,21 +83,32 @@ end | |
Attempt to swap the temperatures of two chains by tempering the densities and | ||
calculating the swap acceptance ratio; then swapping if it is accepted. | ||
""" | ||
function swap_attempt(model, sampler, ts, k, adapt, n) | ||
|
||
logπk, logπkp1, θk, θkp1 = get_tempered_loglikelihoods_and_params(model, sampler, ts.states, k, ts.Δ, ts.chain_index) | ||
function swap_attempt(rng, model, sampler, ts, k, adapt, n) | ||
# Extract the relevant transitions. | ||
transitionk = first(ts.states[ts.chain_index[k]]) | ||
transitionkp1 = first(ts.states[ts.chain_index[k + 1]]) | ||
# Evaluate logdensity for both parameters for each tempered density. | ||
logπk_θk, logπk_θkp1 = compute_tempered_logdensities( | ||
model, sampler, transitionk, transitionkp1, ts.Δ[k] | ||
) | ||
logπkp1_θkp1, logπkp1_θk = compute_tempered_logdensities( | ||
model, sampler, transitionkp1, transitionk, ts.Δ[k + 1] | ||
) | ||
|
||
swap_ar = swap_acceptance_pt(logπk, logπkp1, θk, θkp1) | ||
U = rand(Distributions.Uniform(0, 1)) | ||
|
||
# If the proposed temperature swap is accepted according to swap_ar and U, swap the temperatures for future steps | ||
if U ≤ swap_ar | ||
ts.Δ_index, ts.chain_index = swap_betas(ts.chain_index, k) | ||
# If the proposed temperature swap is accepted according `logα`, | ||
# swap the temperatures for future steps. | ||
logα = swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) | ||
if -Random.randexp(rng) ≤ logα | ||
Δ_index, chain_index = swap_betas(ts.chain_index, k) | ||
@set! ts.Δ_index = Δ_index | ||
@set! ts.chain_index = chain_index | ||
end | ||
|
||
# Adaptation steps affects Ρ and Δ, as the Ρ is adapted before a new Δ is generated and returned | ||
if adapt | ||
ts.Ρ, ts.Δ = adapt_ladder(ts.Ρ, ts.Δ, k, swap_ar, n) | ||
P, Δ = adapt_ladder(ts.Ρ, ts.Δ, k, min(one(logα), exp(logα)), n) | ||
@set! ts.Ρ = P | ||
@set! ts.Δ = Δ | ||
end | ||
return ts | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not quite sure about the use of ConcreteStructs. I found the type annotations helpful for understanding the code. Also, we would like to keep the dependency absolutely minimal where possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feel free to copy and paste my macro from https://github.com/JuliaNonconvex/NonconvexCore.jl/blob/master/src/utilities/params.jl where you can keep the type annotations and make it concrete at the same time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With regards to minimizing dependencies, ConcreteStructs seems to take .004 seconds to load on my computer, so I don't think it's that big a deal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mohamed82008 Do you think that might be a good macro to add to ConcreteStructs.jl (so it can be used more easily outside of NonconvexCore)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes if the authors of ConcreteStructs are ok with the move.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though I get this sentiment, I think in general you want to avoid putting explicit type-constraints on parameteric types, in particular here where IMO the constraints where already stronger than necessary.
Oh this looks quite nice! Might be worth doing that instead:)
Highly recommend you propose it! From a quick glance it just seems superior to
@concrete
in any way, no?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't know about
@concrete
when I did this. I think mine is older (from years ago).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But IMO we shouldn't even restrict to something like a
AbstractVector
, e.g. for a small number of temperatures it might be better to use aTuple
, etc. Instead we should just document the fields properly.