From 3a4b384b656a1eb7bff4f017f1f0cbf6fc5beca7 Mon Sep 17 00:00:00 2001
From: Jaime RZ
Date: Fri, 28 Jul 2023 23:34:31 +0100
Subject: [PATCH] NUTS kernel options (#342)

* pass options

* bump

* Update src/abstractmcmc.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* tests for kernel hyperparameters

* format

* test

* bring back all tests

* bug

* make_init_params bug

* more tests+ init_params bug

* more tests+ init_params bug

* tests for bug

* format

* catch HMC case

* Typofix.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge
---
 Project.toml         |  2 +-
 src/abstractmcmc.jl  | 13 ++++++++---
 test/constructors.jl | 54 ++++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/Project.toml b/Project.toml
index 9b0679a7..7e006f1b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "AdvancedHMC"
 uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
-version = "0.5.1"
+version = "0.5.2"

 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl
index ff0cf2a9..a2dd8b5b 100644
--- a/src/abstractmcmc.jl
+++ b/src/abstractmcmc.jl
@@ -117,7 +117,7 @@ function AbstractMCMC.step(

     # Define integration algorithm
     # Find good eps if not provided one
-    init_params = make_init_params(spl, logdensity, init_params)
+    init_params = make_init_params(rng, spl, logdensity, init_params)
     ϵ = make_step_size(rng, spl, hamiltonian, init_params)
     integrator = make_integrator(spl, ϵ)

@@ -251,7 +251,12 @@ end
 #############
 ### Utils ###
 #############
-function make_init_params(spl::AbstractHMCSampler, logdensity, init_params)
+function make_init_params(
+    rng::AbstractRNG,
+    spl::AbstractHMCSampler,
+    logdensity,
+    init_params,
+)
     T = sampler_eltype(spl)
     if init_params == nothing
         d = LogDensityProblems.dimension(logdensity)
@@ -354,7 +359,9 @@ end
 #########

 function make_kernel(spl::NUTS, integrator::AbstractIntegrator)
-    return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
+    return HMCKernel(
+        Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn(spl.max_depth, spl.Δ_max)),
+    )
 end

 function make_kernel(spl::HMC, integrator::AbstractIntegrator)
diff --git a/test/constructors.jl b/test/constructors.jl
index b5f3a8ac..8d66a79e 100644
--- a/test/constructors.jl
+++ b/test/constructors.jl
@@ -1,9 +1,19 @@
 using AdvancedHMC, AbstractMCMC, Random
 include("common.jl")

+get_kernel_hyperparams(spl::HMC, state) = state.κ.τ.termination_criterion.L
+get_kernel_hyperparams(spl::HMCDA, state) = state.κ.τ.termination_criterion.λ
+get_kernel_hyperparams(spl::NUTS, state) =
+    state.κ.τ.termination_criterion.max_depth, state.κ.τ.termination_criterion.Δ_max
+
+get_kernel_hyperparamsT(spl::HMC, state) = typeof(state.κ.τ.termination_criterion.L)
+get_kernel_hyperparamsT(spl::HMCDA, state) = typeof(state.κ.τ.termination_criterion.λ)
+get_kernel_hyperparamsT(spl::NUTS, state) = typeof(state.κ.τ.termination_criterion.Δ_max)
+
 @testset "Constructors" begin
     d = 2
     θ_init = randn(d)
+    rng = Random.default_rng()
     model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)

     @testset "$T" for T in [Float32, Float64]
@@ -14,6 +24,7 @@ include("common.jl")
                     adaptor_type = NoAdaptation,
                     metric_type = DiagEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = 25,
                 ),
             ),
             (
@@ -22,6 +33,7 @@ include("common.jl")
                     adaptor_type = NoAdaptation,
                     metric_type = DiagEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = 25,
                 ),
             ),
             (
@@ -30,6 +42,7 @@ include("common.jl")
                     adaptor_type = NoAdaptation,
                     metric_type = DiagEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = 25,
                 ),
             ),
             (
@@ -38,6 +51,7 @@ include("common.jl")
                     adaptor_type = NoAdaptation,
                     metric_type = UnitEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = 25,
                 ),
             ),
             (
@@ -46,6 +60,7 @@ include("common.jl")
                     adaptor_type = NoAdaptation,
                     metric_type = DenseEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = 25,
                 ),
             ),
             (
@@ -54,6 +69,7 @@ include("common.jl")
                     adaptor_type = NesterovDualAveraging,
                     metric_type = DiagEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = one(T),
                 ),
             ),
             # This should perform the correct promotion for the 2nd argument.
@@ -63,14 +79,16 @@ include("common.jl")
                     adaptor_type = NesterovDualAveraging,
                     metric_type = DiagEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = one(T),
                 ),
             ),
             (
-                NUTS(T(0.8)),
+                NUTS(T(0.8); max_depth = 20, Δ_max = T(2000.0)),
                 (
                     adaptor_type = StanHMCAdaptor,
                     metric_type = DiagEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = (20, T(2000.0)),
                 ),
             ),
             (
@@ -79,6 +97,7 @@ include("common.jl")
                     adaptor_type = StanHMCAdaptor,
                     metric_type = UnitEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = (10, T(1000.0)),
                 ),
             ),
             (
@@ -87,6 +106,7 @@ include("common.jl")
                     adaptor_type = StanHMCAdaptor,
                     metric_type = DenseEuclideanMetric{T},
                     integrator_type = Leapfrog{T},
+                    kernel_hp = (10, T(1000.0)),
                 ),
             ),
             (
@@ -95,6 +115,7 @@ include("common.jl")
                     adaptor_type = StanHMCAdaptor,
                     metric_type = DiagEuclideanMetric{T},
                     integrator_type = JitteredLeapfrog{T,T},
+                    kernel_hp = (10, T(1000.0)),
                 ),
             ),
             (
@@ -103,6 +124,7 @@ include("common.jl")
                     adaptor_type = StanHMCAdaptor,
                     metric_type = DiagEuclideanMetric{T},
                     integrator_type = TemperedLeapfrog{T,T},
+                    kernel_hp = (10, T(1000.0)),
                 ),
             ),
         ]
@@ -110,7 +132,6 @@ include("common.jl")
             @test AdvancedHMC.sampler_eltype(sampler) == T

             # Step.
-            rng = Random.default_rng()
             transition, state =
                 AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init)

@@ -126,6 +147,35 @@ include("common.jl")
             @test AdvancedHMC.getmetric(state) isa expected.metric_type
             @test AdvancedHMC.getintegrator(state) isa expected.integrator_type
             @test AdvancedHMC.getadaptor(state) isa expected.adaptor_type
+
+            # Verify that the kernel is receiving the hyperparameters
+            @test get_kernel_hyperparams(sampler, state) == expected.kernel_hp
+            if typeof(sampler) <: HMC
+                @test get_kernel_hyperparamsT(sampler, state) == Int64
+            else
+                @test get_kernel_hyperparamsT(sampler, state) == T
+            end
         end
     end
 end
+
+@testset "Utils" begin
+    @testset "init_params" begin
+        d = 2
+        θ_init = randn(d)
+        rng = Random.default_rng()
+        model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
+        logdensity = model.logdensity
+        spl = NUTS(0.8)
+        T = AdvancedHMC.sampler_eltype(spl)
+
+        metric = make_metric(spl, logdensity)
+        hamiltonian = Hamiltonian(metric, model)
+
+        init_params1 = make_init_params(rng, spl, logdensity, nothing)
+        @test typeof(init_params1) == Vector{T}
+        @test length(init_params1) == d
+        init_params2 = make_init_params(rng, spl, logdensity, θ_init)
+        @test init_params2 === θ_init
+    end
+end
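
Usage sketch, for reference: the snippet below exercises the NUTS kernel options introduced by this patch through the AbstractMCMC interface, mirroring test/constructors.jl. It assumes AdvancedHMC 0.5.2 together with AbstractMCMC, LogDensityProblems, and Random; the StandardGaussian target, its dimension, and the initial parameters are illustrative stand-ins, not anything defined in this repository.

using AdvancedHMC, AbstractMCMC, LogDensityProblems, Random

# Illustrative target: a standard normal implementing the LogDensityProblems
# interface with an analytic gradient (order-1 capability).
struct StandardGaussian
    dim::Int
end
LogDensityProblems.logdensity(p::StandardGaussian, θ) = -sum(abs2, θ) / 2
LogDensityProblems.logdensity_and_gradient(p::StandardGaussian, θ) =
    (LogDensityProblems.logdensity(p, θ), -θ)
LogDensityProblems.dimension(p::StandardGaussian) = p.dim
LogDensityProblems.capabilities(::Type{StandardGaussian}) =
    LogDensityProblems.LogDensityOrder{1}()

rng = Random.default_rng()
model = AbstractMCMC.LogDensityModel(StandardGaussian(2))

# The new keyword arguments are forwarded to GeneralisedNoUTurn by make_kernel.
sampler = NUTS(0.8; max_depth = 20, Δ_max = 2000.0)

# One step, as in test/constructors.jl; the returned state carries the kernel.
transition, state =
    AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = randn(rng, 2))

# The termination criterion now reflects the requested hyperparameters.
state.κ.τ.termination_criterion.max_depth  # 20
state.κ.τ.termination_criterion.Δ_max      # 2000.0

A full chain can be drawn the same way with AbstractMCMC.sample, which forwards its keyword arguments (n_adapts, init_params) to step.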