diff --git a/.gitignore b/.gitignore index fcbca32..534f7b3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -/Manifest.toml +Manifest.toml *.DS_Store -*.png -deprecated +working_code diff --git a/Project.toml b/Project.toml index fa54387..af97b01 100644 --- a/Project.toml +++ b/Project.toml @@ -5,16 +5,19 @@ version = "0.1.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] -AbstractMCMC = "3.2" +AbstractMCMC = "3.2, 4" +ConcreteStructs = "0.2" Distributions = "0.24, 0.25" +DocStringExtensions = "0.8, 0.9" +InverseFunctions = "0.1" +Setfield = "0.7, 0.8, 1" julia = "1" - -[extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test"] diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 36041f8..1752dfc 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -4,15 +4,30 @@ import AbstractMCMC import Distributions import Random +using ProgressLogging: ProgressLogging +using ConcreteStructs: @concrete +using Setfield: @set, @set! 
+ +using InverseFunctions + +using DocStringExtensions + include("adaptation.jl") -include("tempered.jl") +include("swapping.jl") +include("state.jl") +include("sampler.jl") +include("sampling.jl") include("ladders.jl") include("stepping.jl") include("model.jl") -include("swapping.jl") -include("plotting.jl") -export tempered, TemperedSampler, plot_swaps, plot_ladders, make_tempered_model, get_tempered_loglikelihoods_and_params, make_tempered_loglikelihood, get_params +export tempered, + tempered_sample, + TemperedSampler, + make_tempered_model, + StandardSwap, + RandomPermutationSwap, + NonReversibleSwap function AbstractMCMC.bundle_samples( ts::Vector, @@ -22,7 +37,10 @@ function AbstractMCMC.bundle_samples( chain_type::Type; kwargs... ) - AbstractMCMC.bundle_samples(ts, model, sampler.internal_sampler, state, chain_type; kwargs...) + AbstractMCMC.bundle_samples( + ts, model, sampler_for_chain(sampler, state, 1), state_for_chain(state, 1), chain_type; + kwargs... + ) end end diff --git a/src/adaptation.jl b/src/adaptation.jl index 31927b7..5d80a9d 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -1,54 +1,240 @@ +using Distributions: StatsFuns -struct PolynomialStep - η :: Real - c :: Real +@concrete struct PolynomialStep + η + c end function get(step::PolynomialStep, k::Real) - step.c * (k + 1.) ^ (-step.η) + return step.c * (k + 1) ^ (-step.η) end +""" + Geometric -struct AdaptiveState - swap_target_ar :: Real - scale :: Base.RefValue{<:Real} - step :: PolynomialStep +Specifies a geometric schedule for the inverse temperatures. + +See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and +[`weight`](@ref). +""" +struct Geometric end + +defaultscale(::Geometric, Δ) = eltype(Δ)(0.9) + +""" + InverselyAdditive + +Specifies an additive schedule for the temperatures (not _inverse_ temperatures). + +See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and +[`weight`](@ref). 
+""" +struct InverselyAdditive end + +defaultscale(::InverselyAdditive, Δ) = eltype(Δ)(0.9) + +struct AdaptiveState{S,T1<:Real,T2<:Real,P<:PolynomialStep} + schedule_type::S + swap_target_ar::T1 + scale_unconstrained::T2 + step::P + n::Int end -function AdaptiveState(swap_target::Real, scale::Real, step::PolynomialStep) - AdaptiveState(swap_target, Ref(log(scale)), step) + +function AdaptiveState(swap_target_ar, scale_unconstrained, step) + return AdaptiveState(InverselyAdditive(), swap_target_ar, scale_unconstrained, step) +end + +function AdaptiveState(schedule_type, swap_target_ar, scale_unconstrained, step) + return AdaptiveState(schedule_type, swap_target_ar, scale_unconstrained, step, 1) end +""" + weight(ρ::AdaptiveState{<:Geometric}) + +Return the weight/scale to be used in the mapping `β[ℓ] ↦ β[ℓ + 1]`. + +# Notes +In Eq. (13) in [^MIAS12] they use the relation + + β[ℓ + 1] = β[ℓ] * w(ρ) + +with + + w(ρ) = exp(-exp(ρ)) + +because we want `w(ρ) ∈ (0, 1)` while `ρ ∈ ℝ`. As an alternative, we use +`StatsFuns.logistic(ρ)` which is numerically more stable than `exp(-exp(ρ))` and +leads to less extreme values, i.e. 0 or 1. + +This the same approach as mentioned in [^ATCH11]. + +# References +[^MIAS12]: Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). +[^ATCH11]: Atchade, Yves F, Roberts, G. O., & Rosenthal, J. S., Towards optimal scaling of metropolis-coupled markov chain monte carlo, Statistics and Computing, 21(4), 555–568 (2011). +""" +weight(ρ::AdaptiveState{<:Geometric}) = geometric_weight_constrain(ρ.scale_unconstrained) +geometric_weight_constrain(x) = StatsFuns.logistic(x) +geometric_weight_unconstrain(y) = inverse(StatsFuns.logistic)(y) + +""" + weight(ρ::AdaptiveState{<:InverselyAdditive}) + +Return the weight/scale to be used in the mapping `β[ℓ] ↦ β[ℓ + 1]`. 
+""" +weight(ρ::AdaptiveState{<:InverselyAdditive}) = inversely_additive_weight_constrain(ρ.scale_unconstrained) +inversely_additive_weight_constrain(x) = exp(-x) +inversely_additive_weight_unconstrain(y) = -log(y) + +function init_adaptation( + schedule::InverselyAdditive, + Δ::Vector{<:Real}, + swap_target::Real, + scale::Real, + η::Real, + stepsize::Real +) + N_it = length(Δ) + step = PolynomialStep(η, stepsize) + # TODO: One common state or one per temperature? + # ρs = [ + # AdaptiveState(schedule, swap_target, inversely_additive_weight_unconstrain(scale), step) + # for _ in 1:(N_it - 1) + # ] + ρs = AdaptiveState(schedule, swap_target, inversely_additive_weight_unconstrain(scale), step) + return ρs +end function init_adaptation( + schedule::Geometric, Δ::Vector{<:Real}, swap_target::Real, scale::Real, - γ::Real + η::Real, + stepsize::Real ) - Nt = length(Δ) - step = PolynomialStep(γ, Nt - 1) - Ρ = [AdaptiveState(swap_target, scale, step) for _ in 1:(Nt - 1)] - return Ρ + N_it = length(Δ) + step = PolynomialStep(η, stepsize) + # TODO: One common state or one per temperature? + # ρs = [ + # AdaptiveState(schedule, swap_target, geometric_weight_unconstrain(scale), step) + # for _ in 1:(N_it - 1) + # ] + ρs = AdaptiveState(schedule, swap_target, geometric_weight_unconstrain(scale), step) + return ρs end -function rhos_to_ladder(Ρ, Δ) - β′ = Δ[1] - for i in 1:length(Ρ) - β′ += exp(Ρ[i].scale[]) - Δ[i + 1] = Δ[1] / β′ +""" + adapt!!(ρ::AdaptiveState, swap_ar) + +Return updated `ρ` based on swap acceptance ratio `swap_ar`. + +See [`update_inverse_temperatures`](@ref) to see how we compute the resulting +inverse temperatures from the adapted state `ρ`. 
+""" +function adapt!!(ρ::AdaptiveState, swap_ar) + swap_diff = ρ.swap_target_ar - swap_ar + γ = get(ρ.step, ρ.n) + ρ_new = @set ρ.scale_unconstrained = ρ.scale_unconstrained + γ * swap_diff + return @set ρ_new.n += 1 +end + +""" + adapt!!(ρ::AdaptiveState, Δ, k, swap_ar) + adapt!!(ρ::AbstractVector{<:AdaptiveState}, Δ, k, swap_ar) + +Return adapted state(s) given that we just proposed a swap of the `k`-th +and `(k + 1)`-th temperatures with acceptance ratio `swap_ar`. +""" +adapt!!(ρ::AdaptiveState, Δ, k, swap_ar) = adapt!!(ρ, swap_ar) +function adapt!!(ρs::AbstractVector{<:AdaptiveState}, Δ, k, swap_ar) + ρs[k] = adapt!!(ρs[k], swap_ar) + return ρs +end + +""" + update_inverse_temperatures(ρ::AdaptiveState{<:Geometric}, Δ_current) + update_inverse_temperatures(ρ::AbstractVector{<:AdaptiveState{<:Geometric}}, Δ_current) + +Return updated inverse temperatures computed from adaptation state(s) and `Δ_current`. + +This update is similar to Eq. (13) in [^MIAS12], with the only possible deviation +being how we compute the scaling factor from `ρ`: see [`weight`](@ref) for information. + +If `ρ` is an `AbstractVector`, then it should be of length `length(Δ_current) - 1`, +with `ρ[k]` corresponding to the adaptation state for the `k`-th inverse temperature. + +# References +[^MIAS12]: Miasojedow, B., Moulines, E., & Vihola, M., Adaptive Parallel Tempering Algorithm, (2012). +""" +function update_inverse_temperatures(ρ::AdaptiveState{<:Geometric}, Δ_current) + Δ = similar(Δ_current) + β₀ = Δ_current[1] + Δ[1] = β₀ + + β = β₀ + for ℓ in 1:length(Δ) - 1 + # TODO: Is it worth it to do this on log-scale instead? 
+ β *= weight(ρ) + @inbounds Δ[ℓ + 1] = β end return Δ end +function update_inverse_temperatures(ρs::AbstractVector{<:AdaptiveState{<:Geometric}}, Δ_current) + Δ = similar(Δ_current) + N = length(Δ) + @assert length(ρs) ≥ N - 1 "number of adaptive states < number of temperatures" + + β₀ = Δ_current[1] + Δ[1] = β₀ -function adapt_rho(ρ::AdaptiveState, swap_ar, n) - swap_diff = swap_ar - ρ.swap_target_ar - γ = get(ρ.step, n) - return γ * swap_diff + β = β₀ + for ℓ in 1:N - 1 + # TODO: Is it worth it to do this on log-scale instead? + β *= weight(ρs[ℓ]) + @inbounds Δ[ℓ + 1] = β + end + return Δ end +""" + update_inverse_temperatures(ρ::AdaptiveState{<:InverselyAdditive}, Δ_current) + update_inverse_temperatures(ρ::AbstractVector{<:AdaptiveState{<:InverselyAdditive}}, Δ_current) + +Return updated inverse temperatures computed from adaptation state(s) and `Δ_current`. + +This update increments the temperature (not _inverse_ temperature) by a positive constant, +which is adapted by `ρ`. + +If `ρ` is a `AbstractVector`, then it should be of length `length(Δ_current) - 1`, +with `ρ[k]` corresponding to the adaptation state for the `k`-th inverse temperature. 
+""" +function update_inverse_temperatures(ρ::AdaptiveState{<:InverselyAdditive}, Δ_current) + Δ = similar(Δ_current) + β₀ = Δ_current[1] + Δ[1] = β₀ -function adapt_ladder(Ρ, Δ, k, swap_ar, n) - Ρ[k].scale[] += adapt_rho(Ρ[k], swap_ar, n) - return Ρ, rhos_to_ladder(Ρ, Δ) -end \ No newline at end of file + T = inv(β₀) + for ℓ in 1:length(Δ) - 1 + T += weight(ρ) + @inbounds Δ[ℓ + 1] = inv(T) + end + return Δ +end + +function update_inverse_temperatures(ρs::AbstractVector{<:AdaptiveState{<:InverselyAdditive}}, Δ_current) + Δ = similar(Δ_current) + N = length(Δ) + @assert length(ρs) ≥ N - 1 "number of adaptive states < number of temperatures" + + β₀ = Δ_current[1] + Δ[1] = β₀ + + T = inv(β₀) + for ℓ in 1:N - 1 + T += weight(ρs[ℓ]) + @inbounds Δ[ℓ + 1] = inv(T) + end + return Δ +end diff --git a/src/ladders.jl b/src/ladders.jl index 2017764..2cccab5 100644 --- a/src/ladders.jl +++ b/src/ladders.jl @@ -1,52 +1,51 @@ """ - get_scaling_val(Nt, swap_strategy) + get_scaling_val(N_it, swap_strategy) -Calculates the correct scaling factor for polynomial step size between temperatures +Calculates a scaling factor for polynomial step size between inverse temperatures. """ -function get_scaling_val(Nt, swap_strategy) - # Why these? - if swap_strategy == :standard - scaling_val = Nt - 1 - elseif swap_strategy == :nonrev - scaling_val = 2 - else - scaling_val = 1 - end - return scaling_val -end - +get_scaling_val(N_it, ::StandardSwap) = N_it - 1 +get_scaling_val(N_it, ::NonReversibleSwap) = 2 +get_scaling_val(N_it, ::RandomPermutationSwap) = 1 """ - generate_Δ(Nt, swap_strategy) + generate_inverse_temperatures(N_it, swap_strategy) -Returns a temperature ladder `Δ` containing `Nt` temperatures, -generated in accordance with the chosen `swap_strategy` +Returns a temperature ladder `Δ` containing `N_it` values, +generated in accordance with the chosen `swap_strategy`. 
""" -function generate_Δ(Nt, swap_strategy) - scaling_val = get_scaling_val(Nt, swap_strategy) - Δ = zeros(Real, Nt) - Δ[1] = 1.0 - β′ = Δ[1] - for i ∈ 1:(Nt - 1) - β′ += exp(scaling_val) - Δ[i + 1] = Δ[1] / β′ +function generate_inverse_temperatures(N_it, swap_strategy) + # Apparently, here we increase the temperature by a constant + # factor which depends on `swap_strategy`. + scaling_val = get_scaling_val(N_it, swap_strategy) + Δ = Vector{Float64}(undef, N_it) + Δ[1] = 1 + T = Δ[1] + for i in 1:(N_it - 1) + T += scaling_val + Δ[i + 1] = inv(T) end return Δ end """ - check_Δ(Δ) + check_inverse_temperatures(Δ) Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0` """ -function check_Δ(Δ) +function check_inverse_temperatures(Δ) + if length(Δ) <= 1 + error("More than one inverse temperatures must be provided.") + end if !all(zero.(Δ) .≤ Δ .≤ one.(Δ)) - error("Temperature schedule provided has values outside of the acceptable range, ensure all values are in [0, 1].") + error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].") end - Δ = sort(Δ; rev=true) - if Δ[1] != one(Δ[1]) - error("Δ must contain 1, as β₀.") + Δ_sorted = sort(Δ; rev=true) + if Δ_sorted[1] != one(Δ_sorted[1]) + error("The temperature ladder must contain 1.") end - return Δ + if Δ_sorted != Δ + println("The temperature was sorted to ensure decreasing order.") + end + return Δ_sorted end diff --git a/src/model.jl b/src/model.jl index 4e1a5ca..38f39d0 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,8 +1,9 @@ +""" + make_tempered_model(sampler, model, args...) -# struct TemperedModel <: AbstractPPL.AbstractProbabilisticProgram -# model :: DynamicPPL.Model -# β :: AbstractFloat -# end +Return an instance representing a model. +The return-type depends on its usage in [`compute_tempered_logdensities`](@ref). 
+""" function make_tempered_model end diff --git a/src/plotting.jl b/src/plotting.jl deleted file mode 100644 index 21c2a8d..0000000 --- a/src/plotting.jl +++ /dev/null @@ -1,12 +0,0 @@ - -""" -When sample is called with the `save_state` kwarg set to `true`, the chain can be used to plot the tempering swaps that occurred during sampling -""" -function plot_swaps(chain) - plot(chain.info.samplerstate.Δ_index_history) -end - - -function plot_ladders(chain) - plot(chain.info.samplerstate.Δ_history) -end diff --git a/src/sampler.jl b/src/sampler.jl new file mode 100644 index 0000000..f180de9 --- /dev/null +++ b/src/sampler.jl @@ -0,0 +1,113 @@ +""" + TemperedSampler <: AbstractMCMC.AbstractSampler + +A `TemperedSampler` struct wraps a sampler upon which to apply the Parallel Tempering algorithm. + +# Fields + +$(FIELDS) +""" +@concrete struct TemperedSampler <: AbstractMCMC.AbstractSampler + "sampler(s) used to target the tempered distributions" + sampler + "collection of inverse temperatures β; β[i] corresponds to the i-th tempered model" + inverse_temperatures + "number of steps of `sampler` to take before proposing swaps" + swap_every + "the swap strategy that will be used when proposing swaps" + swap_strategy + # TODO: This should be replaced with `P` just being some `NoAdapt` type. + "boolean flag specifying whether or not to adapt" + adapt + "adaptation parameters" + adaptation_states +end + +swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy + +getsampler(samplers, I...) = getindex(samplers, I...) +getsampler(sampler::AbstractMCMC.AbstractSampler, I...) = sampler +getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.sampler, I...) + +""" + numtemps(sampler::TemperedSampler) + +Return number of inverse temperatures used by `sampler`. 
+""" +numtemps(sampler::TemperedSampler) = length(sampler.inverse_temperatures) + +""" + sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...]) + +Return the sampler corresponding to the chain indexed by `I...`. +If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned. +""" +sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1) +function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) + return getsampler(sampler.sampler, state.chain_to_process[I...]) +end + +""" + sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) + +Return the sampler corresponding to the process indexed by `I...`. +""" +function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) + return getsampler(sampler.sampler, I...) +end + +""" + tempered(sampler, inverse_temperatures; kwargs...) + OR + tempered(sampler, N_it; swap_strategy=StandardSwap(), kwargs...) + +Return a tempered version of `sampler` using the provided `inverse_temperatures` or +inverse temperatures generated from `N_it` and the `swap_strategy`. + +# Arguments +- `sampler` is an algorithm or sampler object to be used for underlying sampling and to apply tempering to +- The temperature schedule can be defined either explicitly or just as an integer number of temperatures, i.e. as: + - `inverse_temperatures` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... 
< β₁ < β₀ = 1 + OR + - `N_it`, specifying the integer number of inverse temperatures to include in a generated `inverse_temperatures` + +# Keyword arguments +- `swap_strategy::AbstractSwapStrategy` is the way in which inverse temperature swaps between chains are made +- `swap_every::Integer` steps are carried out between each attempt at a swap + +# See also +- [`TemperedSampler`](@ref) +- For more on the swap strategies: + - [`AbstractSwapStrategy`](@ref) + - [`StandardSwap`](@ref) + - [`NonReversibleSwap`](@ref) + - [`RandomPermutationSwap`](@ref) +""" +function tempered( + sampler::AbstractMCMC.AbstractSampler, + N_it::Integer; + swap_strategy::AbstractSwapStrategy=StandardSwap(), + kwargs... +) + return tempered(sampler, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy = swap_strategy, kwargs...) +end +function tempered( + sampler::AbstractMCMC.AbstractSampler, + inverse_temperatures::Vector{<:Real}; + swap_strategy::AbstractSwapStrategy=StandardSwap(), + swap_every::Integer=2, + adapt::Bool=false, + adapt_target::Real=0.234, + adapt_stepsize::Real=1, + adapt_eta::Real=0.66, + adapt_schedule=Geometric(), + adapt_scale=defaultscale(adapt_schedule, inverse_temperatures), + kwargs... +) + swap_every >= 2 || error("This must be a positive integer greater than 1.") + inverse_temperatures = check_inverse_temperatures(inverse_temperatures) + adaptation_states = init_adaptation( + adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize + ) + return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, adaptation_states) +end diff --git a/src/sampling.jl b/src/sampling.jl new file mode 100644 index 0000000..733158b --- /dev/null +++ b/src/sampling.jl @@ -0,0 +1,55 @@ +""" + tempered_sample([rng, ], model, sampler, N, inverse_temperatures; kwargs...) + OR + tempered_sample([rng, ], model, sampler, N, N_it; swap_strategy=StandardSwap(), kwargs...) 
+ +Generate `N` samples from `model` using a tempered version of the provided `sampler`. +Provide either `inverse_temperatures` directly, or `N_it` (and a `swap_strategy`) from which the inverse temperatures are generated. + +# Keyword arguments +- `N_burnin::Integer` burn-in steps will be carried out before any swapping between chains is attempted +- `swap_strategy::AbstractSwapStrategy` is the way in which inverse temperature swaps between chains are made +- `swap_every::Integer` steps are carried out between each attempt at a swap + +# See also +- [`tempered`](@ref) +- [`TemperedSampler`](@ref) +- For more on the swap strategies: + - [`AbstractSwapStrategy`](@ref) + - [`StandardSwap`](@ref) + - [`NonReversibleSwap`](@ref) + - [`RandomPermutationSwap`](@ref) +""" +function tempered_sample( + model, + sampler::AbstractMCMC.AbstractSampler, + N::Integer, + arg::Union{Integer, Vector{<:Real}}; + kwargs... +) + return tempered_sample(Random.default_rng(), model, sampler, N, arg; kwargs...) +end + +function tempered_sample( + rng, + model, + sampler::AbstractMCMC.AbstractSampler, + N::Integer, + N_it::Integer; + swap_strategy::AbstractSwapStrategy = StandardSwap(), + kwargs... +) + return tempered_sample(rng, model, sampler, N, generate_inverse_temperatures(N_it, swap_strategy); swap_strategy=swap_strategy, kwargs...) +end + +function tempered_sample( + rng, + model, + sampler::AbstractMCMC.AbstractSampler, + N::Integer, + inverse_temperatures::Vector{<:Real}; + kwargs... +) + tempered_sampler = tempered(sampler, inverse_temperatures; kwargs...) + return AbstractMCMC.sample(rng, model, tempered_sampler, N; kwargs...) +end \ No newline at end of file diff --git a/src/state.jl b/src/state.jl new file mode 100644 index 0000000..12be39b --- /dev/null +++ b/src/state.jl @@ -0,0 +1,131 @@ +""" + TemperedState + +A general implementation of a state for a [`TemperedSampler`](@ref). 
+ +# Fields + +$(FIELDS) + +# Description + +Suppose we're running 4 chains `X`, `Y`, `Z`, and `W`, each targeting a distribution for different +(inverse) temperatures `β`, say, `1.0`, `0.75`, `0.5`, and `0.25`, respectively. That is, we're mainly +interested in the chain `(X[1], X[2], … )` which targets the distribution with `β=1.0`. + +Moreover, suppose we also have 4 workers/processes on which we run these chains in "parallel" +(can also be serial wlog). + +We can then perform a swap in two different ways: +1. Swap the _states_ between each process, i.e. permute `transitions_and_states`. +2. Swap the _temperatures_ between each process, i.e. permute `inverse_temperatures`. + +(1) is possibly the most intuitive approach since it means that the i-th worker/process +corresponds to the i-th chain; in this case, process 1 corresponds to `X`, process 2 to `Y`, etc. +The downside is that we need to move (potentially high-dimensional) states between the +workers/processes. + +(2) on the other hand does _not_ preserve the direct process-chain correspondence. +We now need to keep track of which process has which chain; from this we can +reconstruct each of the chains `X`, `Y`, etc. afterwards. +This means that we need only transfer a pair of numbers representing the (inverse) +temperatures between workers rather than the full states. + +This implementation follows approach (2). 
+ +Here's an exemplar realisation of five steps of sampling and swap-attempts: + +``` +Chains: process_to_chain chain_to_process inverse_temperatures[process_to_chain[i]] +| | | | 1 2 3 4 1 2 3 4 1.00 0.75 0.50 0.25 +| | | | + V | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 + Λ | | +| | | | 2 1 3 4 2 1 3 4 0.75 1.00 0.50 0.25 +| | | | +| V | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 +| Λ | +| | | | 2 3 1 4 3 1 2 4 0.75 0.50 1.00 0.25 +| | | | +``` + +In this case, the chain `X` can be reconstructed as: + +```julia +X[1] = states[1].transitions_and_states[1] +X[2] = states[2].transitions_and_states[2] +X[3] = states[3].transitions_and_states[2] +X[4] = states[4].transitions_and_states[3] +X[5] = states[5].transitions_and_states[3] +``` + +The indices here are exactly those represented by `states[k].chain_to_process[1]`. +""" +@concrete struct TemperedState + "collection of `(transition, state)` pairs for each process" + transitions_and_states + "collection of (inverse) temperatures β corresponding to each process" + inverse_temperatures + "collection of indices such that `chain_to_process[i] = j` if the j-th process corresponds to the i-th chain" + chain_to_process + "collection of indices such that `process_to_chain[j] = i` if the i-th chain corresponds to the j-th process" + process_to_chain + "total number of steps taken" + total_steps + "number of burn-in steps taken" + burnin_steps + "contains all necessary information for adaptation of inverse_temperatures" + adaptation_states + "flag which specifies whether this was a swap-step or not" + is_swap + "swap acceptance ratios on log-scale" + swap_acceptance_ratios +end + +""" + transition_for_chain(state[, I...]) + +Return the transition corresponding to the chain indexed by `I...`. +If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`. +""" +transition_for_chain(state::TemperedState) = transition_for_chain(state, 1) +transition_for_chain(state::TemperedState, I...) 
= state.transitions_and_states[state.chain_to_process[I...]][1] + +""" + transition_for_process(state, I...) + +Return the transition corresponding to the process indexed by `I...`. +""" +transition_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][1] + +""" + state_for_chain(state[, I...]) + +Return the state corresponding to the chain indexed by `I...`. +If `I...` is not specified, the state corresponding to `β=1.0` will be returned. +""" +state_for_chain(state::TemperedState) = state_for_chain(state, 1) +state_for_chain(state::TemperedState, I...) = state.transitions_and_states[state.chain_to_process[I...]][2] + +""" + state_for_process(state, I...) + +Return the state corresponding to the process indexed by `I...`. +""" +state_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][2] + +""" + β_for_chain(state[, I...]) + +Return the β corresponding to the chain indexed by `I...`. +If `I...` is not specified, the β corresponding to `β=1.0` will be returned. +""" +β_for_chain(state::TemperedState) = β_for_chain(state, 1) +β_for_chain(state::TemperedState, I...) = state.inverse_temperatures[state.chain_to_process[I...]] + +""" + β_for_process(state, I...) + +Return the β corresponding to the process indexed by `I...`. +""" +β_for_process(state::TemperedState, I...) = state.inverse_temperatures[I...] 
\ No newline at end of file diff --git a/src/stepping.jl b/src/stepping.jl index 3521e81..33af093 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -1,156 +1,199 @@ """ - mutable struct TemperedState - states :: Array{Any} - Δ :: Vector{<:Real} - Δ_index :: Vector{<:Integer} - chain_index :: Vector{<:Integer} - step_counter :: Integer - total_steps :: Integer - Δ_history :: Array{<:Real, 2} - Δ_index_history :: Array{<:Integer, 2} - Ρ :: Vector{AdaptiveState} - end + should_swap(sampler, state) -A `TemperedState` struct contains the `states` of each of the parallel chains -used throughout parallel tempering as pairs of `Transition`s and `VarInfo`s, -it also stores necessary information for tempering: -- `states` is an Array of pairs of `Transition`s and `VarInfo`s, one for each - tempered chain -- `Δ` contains the ordered sequence of inverse temperatures -- `Δ_index` contains the current ordering to apply the temperatures to each chain, tracking swaps, - i.e., contains the index `Δ_index[i] = j` of the temperature in `Δ`, `Δ[j]`, to apply to chain `i` -- `chain_index` contains the index `chain_index[i] = k` of the chain tempered by `Δ[i]` - NOTE that to convert between this and `Δ_index` we simply use the `sortperm()` function -- `step_counter` maintains the number of steps taken since the last swap attempt -- `total_steps` maintains the count of the total number of steps taken -- `Δ_index_history` records the history of swaps that occur in sampling by recording the `Δ_index` at each step -- `Δ_history` records the values of the inverse temperatures, these will change if adaptation is being used -- `Ρ` contains all of the information required for adaptation of Δ - -Example of swaps across 4 chains and the values of `chain_index` and `Δ_index`: - -Chains: chain_index: Δ_index: -| | | | 1 2 3 4 1 2 3 4 -| | | | - V | | 2 1 3 4 2 1 3 4 - Λ | | -| | | | 2 1 3 4 2 1 3 4 -| | | | -| V | 2 3 1 4 3 1 2 4 -| Λ | -| | | | 2 3 1 4 3 1 2 4 -| | | | +Return `true` if a swap 
should happen at this iteration, and `false` otherwise. """ -mutable struct TemperedState - states :: Array{Any} - Δ :: Vector{<:Real} - Δ_index :: Vector{<:Integer} - chain_index :: Vector{<:Integer} - step_counter :: Integer - total_steps :: Integer - Δ_history :: Array{<:Real, 2} - Δ_index_history :: Array{<:Integer, 2} - Ρ :: Vector{AdaptiveState} +function should_swap(sampler::TemperedSampler, state::TemperedState) + return state.total_steps % sampler.swap_every == 0 end - -""" -For each `β` in `Δ`, carry out a step with a tempered model at the corresponding `β` inverse temperature, -resulting in a list of transitions and states, the transition associated with `β₀ = 1` is then returned with the -rest of the information being stored in the state. -""" function AbstractMCMC.step( rng::Random.AbstractRNG, model, - spl::TemperedSampler; + sampler::TemperedSampler; + N_burnin::Integer=0, + burnin_progress::Bool=AbstractMCMC.PROGRESS[], + init_params=nothing, kwargs... ) - states = [ + + # `TemperedState` has the transitions and states in the order of + # the processes, and performs swaps by moving the (inverse) temperatures + # `β` between the processes, rather than moving states between processes + # and keeping the `β` local to each process. + # + # Therefore we iterate over the processes and then extract the corresponding + # `β`, `sampler` and `state`, and take a initialize. + transitions_and_states = [ AbstractMCMC.step( rng, - make_tempered_model(model, spl.Δ[spl.Δ_init[i]]), - spl.internal_sampler; + make_tempered_model(sampler, model, sampler.inverse_temperatures[i]), + getsampler(sampler, i); + init_params=init_params !== nothing ? init_params[i] : nothing, kwargs... 
) - for i in 1:length(spl.Δ) + for i in 1:numtemps(sampler) ] - return ( - states[sortperm(spl.Δ_init)[1]][1], - TemperedState( - states,spl.Δ, spl.Δ_init, sortperm(spl.Δ_init), 1, 1, Array{Real, 2}(spl.Δ'), Array{Integer, 2}(spl.Δ_init'), spl.Ρ - ) + + # Make sure to collect, because we'll be using `setindex!(!)` later. + process_to_chain = collect(1:length(sampler.inverse_temperatures)) + # Need to `copy` because this might be mutated. + chain_to_process = copy(process_to_chain) + state = TemperedState( + transitions_and_states, + sampler.inverse_temperatures, + process_to_chain, + chain_to_process, + 1, + 0, + sampler.adaptation_states, + false, + Dict{Int,Float64}() ) + + if N_burnin > 0 + AbstractMCMC.@ifwithprogresslogger burnin_progress name = "Burn-in" begin + # Determine threshold values for progress logging + # (one update per 0.5% of progress) + if burnin_progress + threshold = N_burnin ÷ 200 + next_update = threshold + end + + for i in 1:N_burnin + if burnin_progress && i >= next_update + ProgressLogging.@logprogress i / N_burnin + next_update = i + threshold + end + state = no_swap_step(rng, model, sampler, state; kwargs...) + @set! state.burnin_steps += 1 + end + end + end + + return transition_for_chain(state), state end + function AbstractMCMC.step( rng::Random.AbstractRNG, model, - spl::TemperedSampler, - ts::TemperedState; + sampler::TemperedSampler, + state::TemperedState; kwargs... ) - if ts.step_counter == spl.N_swap - ts = swap_step(rng, model, spl, ts) - ts.step_counter = 0 + # Reset. + @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) + + if should_swap(sampler, state) + state = swap_step(rng, model, sampler, state) + @set! state.is_swap = true else - ts.states = [ - AbstractMCMC.step( - rng, - make_tempered_model(model, ts.Δ[ts.Δ_index[i]]), - spl.internal_sampler, - ts.states[i][2]; - kwargs... - ) - for i in 1:length(ts.Δ) - ] - ts.step_counter += 1 + state = no_swap_step(rng, model, sampler, state; kwargs...) 
+ @set! state.is_swap = false end - ts.Δ_history = vcat(ts.Δ_history, Array{Real, 2}(ts.Δ')) - ts.Δ_index_history = vcat(ts.Δ_index_history, Array{Integer, 2}(ts.Δ_index')) - ts.total_steps += 1 - return ts.states[ts.chain_index[1]][1], ts # Use chain_index[1] to ensure the sample from the target is always returned for the step + @set! state.total_steps += 1 + + # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. + return transition_for_chain(state), state end +function no_swap_step( + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + state::TemperedState; + kwargs... +) + # `TemperedState` has the transitions and states in the order of + # the processes, and performs swaps by moving the (inverse) temperatures + # `β` between the processes, rather than moving states between processes + # and keeping the `β` local to each process. + # + # Therefore we iterate over the processes and then extract the corresponding + # `β`, `sampler` and `state`, and take a step. + @set! state.transitions_and_states = [ + AbstractMCMC.step( + rng, + make_tempered_model(sampler, model, β_for_process(state, i)), + sampler_for_process(sampler, state, i), + state_for_process(state, i); + kwargs... + ) + for i in 1:numtemps(sampler) + ] + + return state +end """ - swap_step(rng, model, spl, ts) + swap_step([strategy::AbstractSwapStrategy, ]rng, model, sampler, state) -Uses the internals of the passed `TemperedSampler` - `spl` - and `TemperedState` - -`ts` - to perform a "swap step" between temperatures, in accordance with the relevant -swap strategy. +Return new `state`, now with temperatures swapped according to `strategy`. + +If no `strategy` is provided, the return-value of [`swapstrategy`](@ref) called on `sampler` +is used. 
""" function swap_step( rng::Random.AbstractRNG, model, - spl::TemperedSampler, - ts::TemperedState + sampler::TemperedSampler, + state::TemperedState ) - L = length(ts.Δ) - 1 - sampler = spl.internal_sampler - - if spl.swap_strategy == :standard + return swap_step(swapstrategy(sampler), rng, model, sampler, state) +end - k = rand(rng, Distributions.Categorical(L)) # Pick randomly from 1, 2, ..., k - 1 - ts = swap_attempt(model, sampler, ts, k, spl.adapt, ts.total_steps / L) +function swap_step( + strategy::StandardSwap, + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + state::TemperedState +) + L = numtemps(sampler) - 1 + k = rand(rng, 1:L) + return swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps / L) +end - else +function swap_step( + strategy::RandomPermutationSwap, + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + state::TemperedState +) + L = numtemps(sampler) - 1 + levels = Vector{Int}(undef, L) + Random.randperm!(rng, levels) - # Define a vector to populate with levels at which to propose swaps according to swap_strategy - levels = Vector{Int}(undef, L) - if spl.swap_strategy == :nonrev - if ts.step_counter % (2 * spl.N_swap) == 0 - levels = 1:2:L - else - levels = 2:2:L - end - elseif spl.swap_strategy == :randperm - randperm!(rng, levels) - end + # Iterate through all levels and attempt swaps. + for k in levels + state = swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps) + end + return state +end - for k in levels - ts = swap_attempt(model, sampler, ts, k, spl.adapt, ts.total_steps) - end +function swap_step( + strategy::NonReversibleSwap, + rng::Random.AbstractRNG, + model, + sampler::TemperedSampler, + state::TemperedState +) + L = numtemps(sampler) - 1 + # Alternate between swapping odds and evens. + levels = if state.total_steps % (2 * sampler.swap_every) == 0 + 1:2:L + else + 2:2:L + end + # Iterate through all levels and attempt swaps. 
+ for k in levels + # TODO: For this swapping strategy, we should really be using the adaptation from Syed et. al. (2019), + # but with the current one: shouldn't we at least divide `state.total_steps` by 2 since it will + # take use two swap-attempts before we have tried swapping all of them (in expectation). + state = swap_attempt(rng, model, sampler, state, k, sampler.adapt, state.total_steps) end - return ts + return state end diff --git a/src/swapping.jl b/src/swapping.jl index 0e1c3c1..14cebc7 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -1,81 +1,137 @@ """ - swap_betas(chain_index, k) + AbstractSwapStrategy -Swaps the `k`th and `k + 1`th temperatures. -Use `sortperm()` to convert the `chain_index` to a `Δ_index` to be used in tempering moves. +Represents a strategy for swapping between parallel chains. + +A concrete subtype is expected to implement the method [`swap_step`](@ref). """ -function swap_betas(chain_index, k) - chain_index[k], chain_index[k + 1] = chain_index[k + 1], chain_index[k] - return sortperm(chain_index), chain_index -end +abstract type AbstractSwapStrategy end -function make_tempered_loglikelihood end -function get_params end +""" + StandardSwap <: AbstractSwapStrategy +At every swap step taken, this strategy samples a single chain index `i` and proposes +a swap between chains `i` and `i + 1`. +This approach goes under a number of names, e.g. Parallel Tempering (PT) MCMC and Replica-Exchange MCMC.[^PTPH05] + +# References +[^PTPH05]: Earl, D. J., & Deem, M. W., Parallel tempering: theory, applications, and new perspectives, Physical Chemistry Chemical Physics, 7(23), 3910–3916 (2005). """ - get_tempered_loglikelihoods_and_params(model, sampler, states, k, Δ, chain_index) +struct StandardSwap <: AbstractSwapStrategy end -Temper the `model`'s density using the `k`th and `k + 1`th temperatures -selected via `Δ` and `chain_index`. 
Then retrieve the parameters using the chains' -current transitions extracted from the collection of `states`. """ -function get_tempered_loglikelihoods_and_params( - model, - sampler::AbstractMCMC.AbstractSampler, - states, - k::Integer, - Δ::Vector{Real}, - chain_index::Vector{<:Integer} -) - - logπk = make_tempered_loglikelihood(model, Δ[k]) - logπkp1 = make_tempered_loglikelihood(model, Δ[k + 1]) - - θk = get_params(states[chain_index[k]][1]) - θkp1 = get_params(states[chain_index[k + 1]][1]) - - return logπk, logπkp1, θk, θkp1 + RandomPermutationSwap <: AbstractSwapStrategy + +At every swap step taken, this strategy randomly shuffles all the chain indices +and then iterates through them, proposing swaps for neighboring chains. +""" +struct RandomPermutationSwap <: AbstractSwapStrategy end + + +""" + NonReversibleSwap <: AbstractSwapStrategy + +At every swap step taken, this strategy _deterministically_ traverses first the +odd chain indices, proposing swaps between neighbors, and then in the _next_ swap step +taken traverses even chain indices, proposing swaps between neighbors. + +See [^SYED19] for more on this approach. + +# References +[^SYED19]: Syed, S., Bouchard-Côté, Alexandre, Deligiannidis, G., & Doucet, A., Non-reversible Parallel Tempering: A Scalable Highly Parallel MCMC Scheme, arXiv:1905.02939, (2019). +""" +struct NonReversibleSwap <: AbstractSwapStrategy end + +""" + swap_betas!(chain_to_process, process_to_chain, k) + +Swaps the `k`th and `k + 1`th temperatures in place. +""" +function swap_betas!(chain_to_process, process_to_chain, k) + # TODO: Use BangBang's `@set!!` to also support tuples? + # Extract the process index for each of the chains. + process_for_chain_k, process_for_chain_kp1 = chain_to_process[k], chain_to_process[k + 1] + + # Switch the mapping of the `chain → process` map. + # The temperature for the k-th chain will now be moved from its current process + # to the process for the (k + 1)-th chain, and vice versa. 
+ chain_to_process[k], chain_to_process[k + 1] = process_for_chain_kp1, process_for_chain_k + + # Swap the mapping of the `process → chain` map. + # The process that used to have the k-th chain, now has the (k+1)-th chain, and vice versa. + process_to_chain[process_for_chain_k], process_to_chain[process_for_chain_kp1] = k + 1, k + return chain_to_process, process_to_chain end """ - swap_acceptance_pt(logπk, logπkp1, θk, θkp1) + compute_tempered_logdensities(model, sampler, transition, transition_other, β) + compute_tempered_logdensities(model, sampler, sampler_other, transition, transition_other, state, state_other, β, β_other) + +Return `(logπ(transition, β), logπ(transition_other, β))` where `logπ(x, β)` denotes the +log-density for `model` with inverse-temperature `β`. +""" +function compute_tempered_logdensities(model, sampler, sampler_other, transition, transition_other, state, state_other, β, β_other) + return compute_tempered_logdensities(model, sampler, transition, transition_other, β) +end + +""" + swap_acceptance_pt(logπk, logπkp1) Calculates and returns the swap acceptance ratio for swapping the temperature of two chains. Using tempered likelihoods `logπk` and `logπkp1` at the chains' -current state parameters `θk` and `θkp1`. +current state parameters. """ -function swap_acceptance_pt(logπk, logπkp1, θk, θkp1) - return min( - 1, - exp(logπkp1(θk) + logπk(θkp1)) / exp(logπk(θk) + logπkp1(θkp1)) - # exp(abs(βk - βkp1) * abs(AdvancedMH.logdensity(model, samplek) - AdvancedMH.logdensity(model, samplekp1))) - ) +function swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) + return (logπkp1_θk + logπk_θkp1) - (logπk_θk + logπkp1_θkp1) end """ - swap_attempt(model, sampler, states, k, Δ, Δ_index) + swap_attempt(rng, model, sampler, state, k, adapt) Attempt to swap the temperatures of two chains by tempering the densities and calculating the swap acceptance ratio; then swapping if it is accepted. 
""" -function swap_attempt(model, sampler, ts, k, adapt, n) - - logπk, logπkp1, θk, θkp1 = get_tempered_loglikelihoods_and_params(model, sampler, ts.states, k, ts.Δ, ts.chain_index) +function swap_attempt(rng, model, sampler, state, k, adapt, total_steps) + # TODO: Allow arbitrary `k` rather than just `k + 1`. + # Extract the relevant transitions. + samplerk = sampler_for_chain(sampler, state, k) + samplerkp1 = sampler_for_chain(sampler, state, k + 1) + transitionk = transition_for_chain(state, k) + transitionkp1 = transition_for_chain(state, k + 1) + statek = state_for_chain(state, k) + statekp1 = state_for_chain(state, k + 1) + βk = β_for_chain(state, k) + βkp1 = β_for_chain(state, k + 1) + # Evaluate logdensity for both parameters for each tempered density. + logπk_θk, logπk_θkp1 = compute_tempered_logdensities( + model, samplerk, samplerkp1, transitionk, transitionkp1, statek, statekp1, βk, βkp1 + ) + logπkp1_θkp1, logπkp1_θk = compute_tempered_logdensities( + model, samplerkp1, samplerk, transitionkp1, transitionk, statekp1, statek, βkp1, βk + ) - swap_ar = swap_acceptance_pt(logπk, logπkp1, θk, θkp1) - U = rand(Distributions.Uniform(0, 1)) - - # If the proposed temperature swap is accepted according to swap_ar and U, swap the temperatures for future steps - if U ≤ swap_ar - ts.Δ_index, ts.chain_index = swap_betas(ts.chain_index, k) + # If the proposed temperature swap is accepted according `logα`, + # swap the temperatures for future steps. + logα = swap_acceptance_pt(logπk_θk, logπk_θkp1, logπkp1_θk, logπkp1_θkp1) + should_swap = -Random.randexp(rng) ≤ logα + if should_swap + swap_betas!(state.chain_to_process, state.process_to_chain, k) end - # Adaptation steps affects Ρ and Δ, as the Ρ is adapted before a new Δ is generated and returned + # Keep track of the (log) acceptance ratios. 
+ state.swap_acceptance_ratios[k] = logα + + # Adaptation steps affects `ρs` and `inverse_temperatures`, as the `ρs` is + # adapted before a new `inverse_temperatures` is generated and returned. if adapt - ts.Ρ, ts.Δ = adapt_ladder(ts.Ρ, ts.Δ, k, swap_ar, n) + ρs = adapt!!( + state.adaptation_states, state.inverse_temperatures, k, min(one(logα), exp(logα)) + ) + @set! state.adaptation_states = ρs + @set! state.inverse_temperatures = update_inverse_temperatures(ρs, state.inverse_temperatures) end - return ts -end \ No newline at end of file + return state +end diff --git a/src/tempered.jl b/src/tempered.jl deleted file mode 100644 index 12cbf59..0000000 --- a/src/tempered.jl +++ /dev/null @@ -1,78 +0,0 @@ -""" - struct TemperedSampler{T} <: AbstractMCMC.AbstractSampler - internal_sampler :: T - Δ :: Vector{<:Real} - Δ_init :: Vector{<:Integer} - N_swap :: Integer - swap_strategy :: Symbol - end - -A `TemperedSampler` struct wraps an `internal_sampler` (could just be an algorithm) alongside: -- A temperature ladder `Δ` containing a list of inverse temperatures `β`s -- The initial state of the tempered chains `Δ_init` in terms of which `β` each chain should begin at -- The number of steps between each temperature swap attempt `N_swap` -- The `swap_strategy` defining how these swaps should be carried out -""" -struct TemperedSampler{A} <: AbstractMCMC.AbstractSampler - internal_sampler :: A - Δ :: Vector{<:Real} - Δ_init :: Vector{<:Integer} - N_swap :: Integer - swap_strategy :: Symbol - adapt :: Bool - Ρ :: Vector{AdaptiveState} -end - - -""" - tempered(internal_sampler, Δ; ) - OR - tempered(internal_sampler, Nt; ) - -# Arguments -- `internal_sampler` is an algorithm or sampler object to be used for underlying sampling and to apply tempering to -- The temperature schedule can be defined either explicitly or just as an integer number of temperatures, i.e. as: - - `Δ::Vector{<:Real}` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... 
< β₁ < β₀ = 1 - OR - - `Nt::Integer`, specifying the number of inverse temperatures to include in a generated `Δ` -- `swap_strategy::Symbol` is the way in which temperature swaps are made, one of: - `:standard` as in original proposed algorithm, a single randomly picked swap is proposed - `:nonrev` alternate even/odd swaps as in Syed, Bouchard-Côté, Deligiannidis, Doucet, arXiv:1905.02939 such that a reverse swap cannot be made in immediate succession - `:randperm` generates a permutation in order to swap in a random order -- `Δ_init::Vector{<:Integer}` is a list containing a sequence including the integers `1:length(Δ)` and determines the starting temperature of each chain - i.e. [3, 1, 2, 4] across temperatures [1.0, 0.1, 0.01, 0.001] would mean the first chain starts at temperature 0.01, second starts at 1.0, etc. -- `N_swap::Integer` steps are carried out between each tempering swap step attempt -""" -function tempered( - internal_sampler, - Δ::Vector{<:Real}; - swap_strategy::Symbol = :standard, - kwargs... -) - return tempered(internal_sampler, check_Δ(Δ), swap_strategy; kwargs...) -end -function tempered( - internal_sampler, - Nt::Integer; - swap_strategy::Symbol = :standard, - kwargs... -) - return tempered(internal_sampler, generate_Δ(Nt, swap_strategy), swap_strategy; kwargs...) -end -function tempered( - internal_sampler, - Δ::Vector{<:Real}, - swap_strategy::Symbol; - Δ_init::Vector{<:Integer} = collect(1:length(Δ)), - N_swap::Integer = 1, - adapt::Bool = true, - adapt_target::Real = 0.234, - adapt_scale::Real = get_scaling_val(length(Δ), swap_strategy), - adapt_step::Real = 0.66, - kwargs... 
-) - length(Δ) > 1 || error("More than one inverse temperatures must be provided.") - N_swap >= 1 || error("This must be a positive integer.") - Ρ = init_adaptation(Δ, adapt_target, adapt_scale, adapt_step) - return TemperedSampler(internal_sampler, Δ, Δ_init, N_swap, swap_strategy, adapt, Ρ) -end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..e09628d --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,16 @@ +[deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +AbstractMCMC = "3.2, 4" +AdvancedMH = "0.6" +Bijectors = "0.10" +Distributions = "0.24, 0.25" +MCMCChains = "5.5" +julia = "1" diff --git a/test/compat.jl b/test/compat.jl new file mode 100644 index 0000000..4c624bc --- /dev/null +++ b/test/compat.jl @@ -0,0 +1 @@ +include("compat/advancedmh.jl") diff --git a/test/compat/advancedmh.jl b/test/compat/advancedmh.jl new file mode 100644 index 0000000..363b06e --- /dev/null +++ b/test/compat/advancedmh.jl @@ -0,0 +1,22 @@ +########################################## +### Make compatible with AdvancedMH.jl ### +########################################## +# Makes the first step possible. +# This constructs the model that are passed to the respective samplers. +function MCMCTempering.make_tempered_model(sampler, m::DensityModel, β) + return DensityModel(Base.Fix1(*, β) ∘ m.logdensity) +end + +# Now we need to make swapping possible, which requires computing +# the log density of the tempered model at the candidate states. 
+function MCMCTempering.compute_tempered_logdensities( + model::DensityModel, + sampler, + transition::AdvancedMH.Transition, + transition_other::AdvancedMH.Transition, + β +) + lp = β * AdvancedMH.logdensity(model, transition.params) + lp_other = β * AdvancedMH.logdensity(model, transition_other.params) + return lp, lp_other +end diff --git a/test/runtests.jl b/test/runtests.jl index 94a02c7..b08c41f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,39 +1,228 @@ using MCMCTempering using Test using Distributions -using Plots using AdvancedMH using MCMCChains +using Bijectors +using LinearAlgebra +using AbstractMCMC + +include("utils.jl") +include("compat.jl") + +""" + test_and_sample_model(model, sampler, inverse_temperatures[, swap_strategy]; kwargs...) + +Run the tempered version of `sampler` on `model` and return the resulting chain. + +Several properties of the tempered sampler are tested before returning: +- No invalid swappings has occured. +- Swaps were successfully performed at least a given portion of the time. + +# Arguments +- `model`: The model to temper and sample from. +- `sampler`: The sampler to temper and use to sample from `model`. +- `inverse_temperatures`: The inverse temperatures to for tempering.. +- `swap_strategy`: The swap strategy to use. + +# Keyword arguments +- `mean_swap_lower_bound`: A lower bound on the acceptance rate of swaps performed, e.g. if set to `0.1` then at least 10% of attempted swaps should be accepted. Defaults to `0.1`. +- `num_iterations`: The number of iterations to run the sampler for. Defaults to `2_000`. +- `swap_every`: The number of iterations between each swap attempt. Defaults to `2`. +- `adapt_target`: The target acceptance rate for the swaps. Defaults to `0.234`. +- `adapt_rtol`: The relative tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.1`. 
+- `adapt_atol`: The absolute tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.05`. +- `kwargs...`: Additional keyword arguments to pass to `MCMCTempering.tempered`. +""" +function test_and_sample_model( + model, + sampler, + inverse_temperatures, + swap_strategy=MCMCTempering.StandardSwap(); + mean_swap_rate_lower_bound=0.1, + num_iterations=2_000, + swap_every=2, + adapt_target=0.234, + adapt_rtol=0.1, + adapt_atol=0.05, + kwargs... +) + # TODO: Remove this when no longer necessary. + num_iterations_tempered = Int(ceil(num_iterations * swap_every ÷ (swap_every - 1))) + + # Make the tempered sampler. + sampler_tempered = tempered( + sampler, + inverse_temperatures; + swap_strategy=swap_strategy, + swap_every=swap_every, + adapt_target=adapt_target, + kwargs... + ) + + # Store the states. + states_tempered = [] + callback = StateHistoryCallback(states_tempered) + + # Sample. + samples_tempered = AbstractMCMC.sample( + model, sampler_tempered, num_iterations_tempered; callback=callback, progress=true + ) + + # Extract the states that were swapped. + states_swapped = filter(Base.Fix2(getproperty, :is_swap), states_tempered) + # Swap acceptance ratios should be compared against the target acceptance in case of adaptation. + swap_acceptance_ratios = mapreduce( + collect ∘ values ∘ Base.Fix2(getproperty, :swap_acceptance_ratios), + vcat, + states_swapped + ) + # Check that adaptation did something useful. + if sampler_tempered.adapt + swap_acceptance_ratios = map(Base.Fix1(min, 1.0) ∘ exp, swap_acceptance_ratios) + empirical_acceptance_rate = sum(swap_acceptance_ratios) / length(swap_acceptance_ratios) + @test adapt_target ≈ empirical_acceptance_rate atol = adapt_atol rtol = adapt_rtol + + # TODO: Maybe check something related to the temperatures themselves in case of adaptation. + # E.g. converged values shouldn't all be 0 or something. 
+ # βs = mapreduce(Base.Fix2(getproperty, :inverse_temperatures), hcat, states) + end + + # Extract the history of chain indices. + process_to_chain_history_list = map(states_tempered) do state + state.process_to_chain + end + process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) + + # Check that the swapping has been done correctly. + process_to_chain_uniqueness = map(states_tempered) do state + length(unique(state.process_to_chain)) == length(state.process_to_chain) + end + @test all(process_to_chain_uniqueness) + + # For the currently implemented strategies, the index process should not move by more than 1. + @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) + + chain_to_process_uniqueness = map(states_tempered) do state + length(unique(state.chain_to_process)) == length(state.chain_to_process) + end + @test all(chain_to_process_uniqueness) + + # Tests that we have at least swapped some times (say at least 10% of attempted swaps). + swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row + # Some of the strategies performs multiple swaps in a swap-iteration, + # but we want to count the number of iterations for which we had a successful swap, + # i.e. only count non-zero elements in a row _once_. Hence the `min`. + min(1, sum(abs, row)) + end + @test sum(swap_success_indicators) ≥ (num_iterations_tempered / swap_every) * mean_swap_rate_lower_bound + + # Compare the tempered sampler to the untempered sampler. + state_tempered = states_tempered[end] + chain_tempered = AbstractMCMC.bundle_samples( + samples_tempered, model, sampler_tempered.sampler, MCMCTempering.state_for_chain(state_tempered), MCMCChains.Chains + ) + # Only pick out the samples after swapping. + # TODO: Remove this when no longer necessary. 
+ chain_tempered = chain_tempered[swap_every:swap_every:end] + return chain_tempered +end + +function compare_chains( + chain::MCMCChains.Chains, chain_tempered::MCMCChains.Chains; + atol=1e-6, rtol=1e-6, + compare_std=true, + compare_ess=true +) + desc = describe(chain)[1].nt + desc_tempered = describe(chain_tempered)[1].nt + + # Compare the means. + @test desc.mean ≈ desc_tempered.mean atol = atol rtol = rtol + + # Compare the std. of the chains. + if compare_std + @test desc.std ≈ desc_tempered.std atol = atol rtol = rtol + end + + # Compare the ESS. + if compare_ess + ess = MCMCChains.ess_rhat(chain).nt.ess + ess_tempered = MCMCChains.ess_rhat(chain_tempered).nt.ess + # HACK: Just make sure it's not doing _horrible_. Though we'd hope it would + # actually do better than the internal sampler. + @test all(ess .≥ ess_tempered .* 0.5) + end +end + @testset "MCMCTempering.jl" begin + @testset "GMM 1D" begin + num_iterations = 100_000 + gmm = MixtureModel(Normal, [(-3, 1.5), (3, 1.5), (15, 1.5), (90, 1.5)], [0.175, 0.25, 0.275, 0.3]) + logdensity(x) = logpdf(gmm, x) + + # Setup non-tempered. + model = AdvancedMH.DensityModel(logdensity) + sampler_rwmh = RWMH(Normal()) + + # Simple geometric ladder + inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.05 .^ [0, 1, 2]) + + # Run the samplers. + chain_tempered = test_and_sample_model( + model, + sampler_rwmh, + [1.0, 0.5, 0.25, 0.125], + num_iterations=num_iterations, + swap_every=2, + adapt=false, + ) - # θᵣ = [-1., 1., 2., 1., 15., 2., 90., 1.5] - # γs = [0.15, 0.25, 0.3, 0.3] + # # Compare the chains. 
+ # compare_chains(chain, chain_tempered, atol=1e-1, compare_std=false, compare_ess=true) + end - # Δ = check_Δ([0, 0.01, 0.1, 0.25, 0.5, 1]) + @testset "MvNormal 2D" begin + d = 2 + num_iterations = 20_000 + swap_every = 2 - # modelᵣ = MixtureModel(Distributions.Normal.(eachrow(reshape(θᵣ, (2,4)))...), γs) - # # xrange = -10:0.1:100 - # # tempered_densities = pdf.(modelᵣ, xrange) .^ Δ' - # # norm_const = sum(tempered_densities[:,1]) - # # for i in 2:length(Δ) - # # tempered_densities[:,i] = (tempered_densities[:,i] ./ sum(tempered_densities[:,i])) .* norm_const - # # end - # # plot(xrange, tempered_densities, label = Δ') + μ_true = [-5.0, 5.0] + σ_true = [1.0, √(10.0)] - # data = rand(modelᵣ, 100) + logdensity(x) = logpdf(MvNormal(μ_true, Diagonal(σ_true .^ 2)), x) - # insupport(θ) = all(reshape(θ, (2,4))[2,:] .≥ 0) - # dist(θ) = MixtureModel(Distributions.Normal.(eachrow(reshape(θ, (2,4)))...), γs) - # density(θ) = insupport(θ) ? sum(logpdf.(dist(θ), data)) : -Inf + # Sampler parameters. + inverse_temperatures = MCMCTempering.check_inverse_temperatures([0.25, 0.5, 0.75, 1.0]) - # # Construct a DensityModel. - # model = DensityModel(density) + # Construct a DensityModel. + model = DensityModel(logdensity) - # # Set up our sampler with a joint multivariate Normal proposal. - # spl = RWMH(MvNormal(8,1)) + # Set up our sampler with a joint multivariate Normal proposal. + sampler = RWMH(MvNormal(zeros(d), Diagonal(σ_true .^ 2))) + # Sample for the non-tempered model for comparison. + samples = AbstractMCMC.sample(model, sampler, num_iterations) + chain = AbstractMCMC.bundle_samples(samples, model, sampler, samples[1], MCMCChains.Chains) - # @test chain, temps = SimulatedTempering(model, spl, Δ, chain_type=Chains) - # @test chains, temps = ParallelTempering(model, spl, Δ, chain_type=Chains) + # Different swap strategies to test. 
+ swapstrategies = [ + MCMCTempering.StandardSwap(), + MCMCTempering.RandomPermutationSwap(), + MCMCTempering.NonReversibleSwap() + ] + @testset "$(swapstrategy)" for swapstrategy in swapstrategies + chain_tempered = test_and_sample_model( + model, + sampler, + inverse_temperatures, + num_iterations=num_iterations, + swap_every=swap_every, + swapstrategy=swapstrategy, + adapt=false, + ) + compare_chains(chain, chain_tempered, rtol=0.1, compare_std=false, compare_ess=true) + end + end end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..e838850 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,28 @@ +""" + StateHistoryCallback + +Defines a callable which pushes the `state` onto the `states` container. + +Example usage when used with AbstractMCMC.jl: +```julia +# 1. Create empty container for state-history. +state_history = [] +# 2. Sample. +AbstractMCMC.sample(model, sampler, N; callback=StateHistoryCallback(state_history)) +# 3. Inspect states. +state_history +``` +""" +struct StateHistoryCallback{A,F} + states::A + selector::F +end +StateHistoryCallback() = StateHistoryCallback(Any[]) +function StateHistoryCallback(states, selector=deepcopy) + return StateHistoryCallback{typeof(states), typeof(selector)}(states, selector) +end + +function (cb::StateHistoryCallback)(rng, model, sampler, sample, state, i; kwargs...) + push!(cb.states, cb.selector(state)) + return nothing +end diff --git a/working_code/bayesode.jl b/working_code/bayesode.jl deleted file mode 100644 index 12732bc..0000000 --- a/working_code/bayesode.jl +++ /dev/null @@ -1,138 +0,0 @@ -using Turing, Distributions, DifferentialEquations - -# Import MCMCChain, Plots, and StatsPlots for visualizations and diagnostics. -using MCMCChains, Plots, StatsPlots - -# Set a seed for reproducibility. 
-using Random -Random.seed!(14); -using MCMCTempering - - - -function lotka_volterra(du,u,p,t) - x, y = u - α, β, γ, δ = p - du[1] = (α - β*y)x # dx = - du[2] = (δ*x - γ)y # dy = -end - -p = [1.5, 1.0, 3.0, 1.0] -u0 = [1.0,1.0] -prob1 = ODEProblem(lotka_volterra,u0,(0.0,10.0),p) -sol = solve(prob1,Tsit5()) -plot(sol) - -sol1 = solve(prob1,Tsit5(),saveat=0.1) -odedata = Array(sol1) + 0.8 * randn(size(Array(sol1))) -plot(sol1, alpha = 0.3, legend = false); scatter!(sol1.t, odedata') - -Turing.setadbackend(:forwarddiff) - -@model function fitlv(data, prob1) - σ ~ InverseGamma(2, 3) # ~ is the tilde character - α ~ truncated(Normal(1.5,0.5),0.5,2.5) - β ~ truncated(Normal(1.0,0.5),0,2) - γ ~ truncated(Normal(3.0,0.5),1,4) - δ ~ truncated(Normal(1.0,0.5),0,2) - - p = [α,β,γ,δ] - prob = remake(prob1, p=p) - predicted = solve(prob,Tsit5(),saveat=0.1) - - for i = 1:length(predicted) - data[:,i] ~ MvNormal(predicted[i], σ) - end -end - -model = fitlv(odedata, prob1) - -# This next command runs 3 independent chains without using multithreading. 
-chain1_mh = sample(model, MH(), MCMCThreads(), 10000, 4) -chain1_nuts = sample(model, NUTS(.65), 1000) -# chain1_mh_test = mapreduce(c -> sample(model, NUTS(), 1000), chainscat, 1:3) - - -chain2_mh = sample(model, Tempered(MH(), 2), MCMCThreads(), 10000, 4) -chain2_nuts = sample(model, Tempered(NUTS(.65), 2), 1000) - -chain3_mh = sample(model, Tempered(MH(), 3), MCMCThreads(), 10000, 30) -chain3_nuts = sample(model, Tempered(NUTS(.65), 3), MCMCThreads(), 1000, 30) - -chain4_mh = sample(model, Tempered(MH(), 4), MCMCThreads(), 10000, 30) -chain4_nuts = sample(model, Tempered(NUTS(.65), 4), MCMCThreads(), 1000, 30) - - -plot_swaps(chain2_mh) - -plot(chain1_mh) -plot(chain2_mh) - -interchain_stats(chain1_mh) -interchain_stats(chain2_mh) - - - -# Pumas example - -function theop_model_Depots1Central1(du, u, p, t) - Depot, Central = u - Ka, CL, Vc = p - du[1] = -Ka * Depot # d Depot = - du[2] = Ka * Depot - (CL / Vc) * Central # d Central = -end - -u0 = [1.0, 1.0] -p = [2.0, 0.2, 0.8, 2.0] -prob = ODEProblem(theop_model_Depots1Central1,u0,(0.0, 10.0),p) -sol = solve(prob, Tsit5()) -plot(sol) - -@model function theopmodel_bayes(dv, SEX, WT) - - N = length(dv) - - θ ~ arraydist(truncated.(Normal.([2.0, 0.2, 0.8, 2.0], 1.0), 0.0, 10.0)) - - ωKa ~ Gamma(1.0, 0.2) - ωCL ~ Gamma(1.0, 0.2) - ωVc ~ Gamma(1.0, 0.2) - - σ ~ Gamma(1.0, 0.5) - - ηKa ~ filldist(Normal(0.0, ωKa), N) - ηCL ~ filldist(Normal(0.0, ωCL), N) - ηVc ~ filldist(Normal(0.0, ωVc), N) - - for i in 1:N - Ka = (SEX[i] == 1 ? 
θ[1] : θ[4]) * exp(ηKa[i]) - CL = θ[2]*(WT[i]/70) * exp(ηCL[i]) - Vc = θ[3] * exp(ηVc[i]) - - p = [Ka, CL, Vc] - prob = remake(prob1, p=p) - predicted = solve(prob, Tsit5(), saveat=0.1) - - μ[i] = predicted[i,2] / Vc - dv[i] .~ Normal.(μ, σ) - end - dv - -end - -using Pumas - - - - -# BayesNeuralODE example - -using BayesNeuralODE - -N = 1 -prior_std = likelihood_std = 1.0 -model = BNO.generate_turing_model(:spiral, N, prior_std, likelihood_std) - -bno_chain_1 = sample(model, NUTS(.6), 100) - -bno_chain_2 = sample(model, Tempered(NUTS(.6), 4), MCMCThreads(), 1000, 4) diff --git a/working_code/experiments.jl b/working_code/experiments.jl deleted file mode 100644 index 091e1cc..0000000 --- a/working_code/experiments.jl +++ /dev/null @@ -1,13 +0,0 @@ -function interchain_stats(chains) - - d = Dict() - - for param in chains.name_map.parameters - μ = std(mean(chains[param], dims=1)) - σ = std(std(chains[param], dims=1)) - push!(d, param => Dict(:μ => μ, :σ => σ)) - end - - return d - -end \ No newline at end of file diff --git a/working_code/neuralode.jl b/working_code/neuralode.jl deleted file mode 100644 index b6c1eb8..0000000 --- a/working_code/neuralode.jl +++ /dev/null @@ -1,74 +0,0 @@ -using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, AdvancedHMC -using JLD, StatsPlots, Distributions - -u0 = [2.0; 0.0] -datasize = 40 -tspan = (0.0, 1) -tsteps = range(tspan[1], tspan[2], length = datasize) - -function trueODEfunc(du, u, p, t) - true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u.^3)'true_A)' -end - -prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -mean_ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) -ode_data = mean_ode_data .+ 0.1 .* randn(size(mean_ode_data)..., 30) - -####DEFINE THE NEURAL ODE##### -dudt2 = FastChain((x, p) -> x.^3, - FastDense(2, 50, relu), - FastDense(50, 2)) -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) - -function predict_neuralode(p) - Array(prob_neuralode(u0, p)) -end -function loss_neuralode(p) - pred = 
predict_neuralode(p) - loss = sum(abs2, ode_data .- pred) - return loss, pred -end - -function l(θ) - lp = logpdf(MvNormal(zeros(length(θ) - 1), 1.0), θ[1:end-1]) - ll = sum(logpdf.(Normal.(ode_data, θ[end]), predict_neuralode(θ[1:end-1]))) - return lp + ll -end -function lp(θ) - return logpdf(MvNormal(zeros(length(θ) - 1), 1.0), θ[1:end-1]) -end -function ll(θ) - return sum(logpdf.(Normal.(ode_data, θ[end]), predict_neuralode(θ[1:end-1]))) -end - -function dldθ(θ) - x, lambda = Flux.Zygote.pullback(l,θ) - grad = first(lambda(1)) - return x, grad -end -function dlpdθ(θ) - x, lambda = Flux.Zygote.pullback(lp,θ) - grad = first(lambda(1)) - return x, grad -end -function dlldθ(θ) - x, lambda = Flux.Zygote.pullback(ll,θ) - grad = first(lambda(1)) - return x, grad -end - -init = [Float64.(prob_neuralode.p); 1.0] - -opt = DiffEqFlux.sciml_train(x -> -l(x), init, ADAM(0.05), maxiters = 1500) -pmin = opt.minimizer; -metric = DiagEuclideanMetric(length(pmin)) -h = Hamiltonian(metric, l, dldθ) -integrator = Leapfrog(find_good_stepsize(h, pmin)) -prop = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator, 10) -adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.5, prop.integrator)) - -samples, stats = sample(h, prop, pmin, 500, adaptor, 500; progress=true) - -using MCMCTempering -tempered_samples = sample() \ No newline at end of file