From c65da1f796732afd83acbbc0d203986e5742e22a Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Tue, 5 Nov 2024 01:34:12 +0000
Subject: [PATCH] Deepcopy adaptor before starting sampling

This avoids the unintuitive behaviour seen in #379
---
 src/sampler.jl     | 36 ++++++++++++++++++++++++++++++++++++
 test/adaptation.jl | 12 +++++++++++-
 test/sampler.jl    | 30 ++++++++++++++++++++++++++++++
 3 files changed, 77 insertions(+), 1 deletion(-)

diff --git a/src/sampler.jl b/src/sampler.jl
index 0d898719..d2053197 100644
--- a/src/sampler.jl
+++ b/src/sampler.jl
@@ -166,6 +166,42 @@ function sample(
     verbose::Bool = true,
     progress::Bool = false,
     (pm_next!)::Function = pm_next!,
+) where {T<:AbstractVecOrMat{<:AbstractFloat}}
+    # Prevent adaptor from being mutated
+    adaptor = deepcopy(adaptor)
+    # Then call sample_mutating_adaptor with the same arguments
+    return sample_mutating_adaptor(
+        rng,
+        h,
+        κ,
+        θ,
+        n_samples,
+        adaptor,
+        n_adapts;
+        drop_warmup = drop_warmup,
+        verbose = verbose,
+        progress = progress,
+        (pm_next!) = pm_next!,
+    )
+end
+
+"""
+    sample_mutating_adaptor(args...; kwargs...)
+
+The same as `sample`, but mutates the `adaptor` argument.
+"""
+function sample_mutating_adaptor(
+    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
+    h::Hamiltonian,
+    κ::HMCKernel,
+    θ::T,
+    n_samples::Int,
+    adaptor::AbstractAdaptor = NoAdaptation(),
+    n_adapts::Int = min(div(n_samples, 10), 1_000);
+    drop_warmup = false,
+    verbose::Bool = true,
+    progress::Bool = false,
+    (pm_next!)::Function = pm_next!,
 ) where {T<:AbstractVecOrMat{<:AbstractFloat}}
     @assert !(drop_warmup && (adaptor isa Adaptation.NoAdaptation)) "Cannot drop warmup samples if there is no adaptation phase."
     # Prepare containers to store sampling results
diff --git a/test/adaptation.jl b/test/adaptation.jl
index 1a644c71..0be29505 100644
--- a/test/adaptation.jl
+++ b/test/adaptation.jl
@@ -14,7 +14,17 @@ function runnuts(ℓπ, metric; n_samples = 3_000)
     integrator = AdvancedHMC.make_integrator(nuts, step_size)
     κ = AdvancedHMC.make_kernel(nuts, integrator)
     adaptor = AdvancedHMC.make_adaptor(nuts, metric, integrator)
-    samples, stats = sample(h, κ, θ_init, n_samples, adaptor, n_adapts; verbose = false)
+    # Use mutating version of sample() here
+    samples, stats = AdvancedHMC.sample_mutating_adaptor(
+        rng,
+        h,
+        κ,
+        θ_init,
+        n_samples,
+        adaptor,
+        n_adapts;
+        verbose = false,
+    )
     return (samples = samples, stats = stats, adaptor = adaptor)
 end
 
diff --git a/test/sampler.jl b/test/sampler.jl
index 109f01a4..dd439008 100644
--- a/test/sampler.jl
+++ b/test/sampler.jl
@@ -62,6 +62,7 @@ end
     n_steps = 10
     n_samples = 22_000
    n_adapts = 4_000
+
     @testset "$metricsym" for (metricsym, metric) in Dict(
         :UnitEuclideanMetric => UnitEuclideanMetric(D),
         :DiagEuclideanMetric => DiagEuclideanMetric(D),
@@ -157,6 +158,7 @@
             end
         end
     end
+
     @testset "drop_warmup" begin
         nuts = NUTS(0.8)
         metric = DiagEuclideanMetric(D)
@@ -191,4 +193,32 @@
         @test length(samples) == n_samples
         @test length(stats) == n_samples
     end
+
+    @testset "reproducibility" begin
+        # Multiple calls to sample() should yield the same results
+        nuts = NUTS(0.8)
+        metric = DiagEuclideanMetric(D)
+        h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
+        integrator = Leapfrog(ϵ)
+        κ = AdvancedHMC.make_kernel(nuts, integrator)
+        adaptor = AdvancedHMC.make_adaptor(nuts, metric, integrator)
+
+        all_samples = []
+        for i = 1:5
+            samples, stats = sample(
+                Random.MersenneTwister(42),
+                h,
+                κ,
+                θ_init,
+                100, # n_samples -- don't need so many
+                adaptor,
+                50, # n_adapts -- likewise
+                verbose = false,
+                progress = false,
+                drop_warmup = true,
+            )
+            push!(all_samples, samples)
+        end
+        @test all(map(s -> s ≈ all_samples[1], all_samples[2:end]))
+    end
 end
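
Note (not part of the patch): a minimal sketch of the behaviour this change guards
against, for anyone who wants to reproduce it outside the test suite. The target
density, gradient, D, θ_init, and the step size 0.1 below are illustrative
stand-ins, not taken from the patch; only the AdvancedHMC calls mirror the
reproducibility test added above.

    using AdvancedHMC, Random

    D = 5                                  # illustrative dimension
    ℓπ(θ) = -sum(abs2, θ) / 2              # stand-in target: log density of N(0, I)
    ∂ℓπ∂θ(θ) = (ℓπ(θ), -θ)                 # returns (log density, gradient)

    metric = DiagEuclideanMetric(D)
    h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
    integrator = Leapfrog(0.1)             # illustrative step size
    nuts = NUTS(0.8)
    κ = AdvancedHMC.make_kernel(nuts, integrator)
    adaptor = AdvancedHMC.make_adaptor(nuts, metric, integrator)
    θ_init = randn(MersenneTwister(0), D)

    # Two identically seeded runs sharing the same adaptor object. Previously the
    # first run mutated `adaptor`, so the second run adapted from a different
    # starting state and could return different draws (#379); with the deepcopy
    # in `sample`, the two runs agree.
    s1, _ = sample(MersenneTwister(42), h, κ, θ_init, 100, adaptor, 50;
                   drop_warmup = true, verbose = false, progress = false)
    s2, _ = sample(MersenneTwister(42), h, κ, θ_init, 100, adaptor, 50;
                   drop_warmup = true, verbose = false, progress = false)
    @assert s1 ≈ s2

Code that relies on the old in-place behaviour (e.g. inspecting the adapted
state afterwards, as test/adaptation.jl does) can call
AdvancedHMC.sample_mutating_adaptor instead.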