diff --git a/Project.toml b/Project.toml index 41e11f26..4f1bdd84 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.6.4" +version = "0.6.5" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/sampler.jl b/src/sampler.jl index 0d898719..9d46d276 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -166,6 +166,42 @@ function sample( verbose::Bool = true, progress::Bool = false, (pm_next!)::Function = pm_next!, +) where {T<:AbstractVecOrMat{<:AbstractFloat}} + # Prevent adaptor from being mutated + adaptor = deepcopy(adaptor) + # Then call sample_mutating_adaptor with the same arguments + return sample_mutating_adaptor( + rng, + h, + κ, + θ, + n_samples, + adaptor, + n_adapts; + drop_warmup = drop_warmup, + verbose = verbose, + progress = progress, + (pm_next!) = pm_next!, + ) +end + +""" + sample_mutating_adaptor(args...; kwargs...) + +The same as `sample`, but mutates the `adaptor` argument. +""" +function sample_mutating_adaptor( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + h::Hamiltonian, + κ::HMCKernel, + θ::T, + n_samples::Int, + adaptor::AbstractAdaptor = NoAdaptation(), + n_adapts::Int = min(div(n_samples, 10), 1_000); + drop_warmup = false, + verbose::Bool = true, + progress::Bool = false, + (pm_next!)::Function = pm_next!, ) where {T<:AbstractVecOrMat{<:AbstractFloat}} @assert !(drop_warmup && (adaptor isa Adaptation.NoAdaptation)) "Cannot drop warmup samples if there is no adaptation phase." # Prepare containers to store sampling results @@ -181,7 +217,6 @@ function sample( nothing time = @elapsed for i = 1:n_samples # Make a transition - # i == 2 && error(κ.τ.integrator) t = transition(rng, h, κ, t.z) # Adapt h and κ; what mutable is the adaptor tstat = stat(t) diff --git a/test/adaptation.jl b/test/adaptation.jl index 1a644c71..0be29505 100644 --- a/test/adaptation.jl +++ b/test/adaptation.jl @@ -14,7 +14,17 @@ function runnuts(ℓπ, metric; n_samples = 3_000) integrator = AdvancedHMC.make_integrator(nuts, step_size) κ = AdvancedHMC.make_kernel(nuts, integrator) adaptor = AdvancedHMC.make_adaptor(nuts, metric, integrator) - samples, stats = sample(h, κ, θ_init, n_samples, adaptor, n_adapts; verbose = false) + # Use mutating version of sample() here + samples, stats = AdvancedHMC.sample_mutating_adaptor( + rng, + h, + κ, + θ_init, + n_samples, + adaptor, + n_adapts; + verbose = false, + ) return (samples = samples, stats = stats, adaptor = adaptor) end diff --git a/test/sampler.jl b/test/sampler.jl index 109f01a4..dd439008 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -62,6 +62,7 @@ end n_steps = 10 n_samples = 22_000 n_adapts = 4_000 + @testset "$metricsym" for (metricsym, metric) in Dict( :UnitEuclideanMetric => UnitEuclideanMetric(D), :DiagEuclideanMetric => DiagEuclideanMetric(D), @@ -157,6 +158,7 @@ end end end end + @testset "drop_warmup" begin nuts = NUTS(0.8) metric = DiagEuclideanMetric(D) @@ -191,4 +193,32 @@ end @test length(samples) == n_samples @test length(stats) == n_samples end + + @testset "reproducibility" begin + # Multiple calls to sample() should yield the same results + nuts = NUTS(0.8) + metric = DiagEuclideanMetric(D) + h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ) + integrator = Leapfrog(ϵ) + κ = AdvancedHMC.make_kernel(nuts, integrator) + adaptor = AdvancedHMC.make_adaptor(nuts, metric, integrator) + + all_samples = [] + for i = 1:5 + samples, stats = sample( + Random.MersenneTwister(42), + h, + κ, + θ_init, + 100, # n_samples -- don't need so many + adaptor, + 50, # n_adapts -- likewise + verbose = false, + progress = false, + drop_warmup = true, + ) + push!(all_samples, samples) + end + @test all(map(s -> s ≈ all_samples[1], all_samples[2:end])) + end end