NUTS kernel options (#342)
* pass options

* bump

* Update src/abstractmcmc.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* tests for kernel hyperparameters

* format

* test

* bring back all tests

* bug

* make_init_params bug

* more tests+ init_params bug

* more tests+ init_params bug

* tests for bug

* format

* catch HMC case

* Typofix.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
3 people authored Jul 28, 2023
1 parent 762e55f commit 3a4b384
Showing 3 changed files with 63 additions and 6 deletions.
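The user-facing effect of this change is that the NUTS termination-criterion hyperparameters passed to the constructor (max_depth, Δ_max) now reach the kernel instead of being replaced by GeneralisedNoUTurn's defaults. A minimal sketch mirroring test/constructors.jl; it assumes it is run from the test/ directory so that common.jl (which defines the demo target ℓπ_gdemo) is available, and the initial parameters are drawn here only for illustration:

using AdvancedHMC, AbstractMCMC, Random
include("common.jl")  # provides ℓπ_gdemo (assumption: run from the test/ directory)

rng = Random.default_rng()
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)

# New in this commit: these keywords are forwarded to the NUTS kernel.
sampler = NUTS(0.8; max_depth = 20, Δ_max = 2000.0)

transition, state =
    AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = randn(rng, 2))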
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.5.1"
version = "0.5.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
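Since the version bump ships this as a new registry release, a downstream environment can request it explicitly once registration (see the JuliaRegistrator comment below) completes; a small sketch using Pkg's standard keyword form:

using Pkg
Pkg.add(name = "AdvancedHMC", version = "0.5.2")  # the release cut by this commit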
13 changes: 10 additions & 3 deletions src/abstractmcmc.jl
@@ -117,7 +117,7 @@ function AbstractMCMC.step(

# Define integration algorithm
# Find good eps if not provided one
-init_params = make_init_params(spl, logdensity, init_params)
+init_params = make_init_params(rng, spl, logdensity, init_params)
ϵ = make_step_size(rng, spl, hamiltonian, init_params)
integrator = make_integrator(spl, ϵ)

@@ -251,7 +251,12 @@ end
#############
### Utils ###
#############
-function make_init_params(spl::AbstractHMCSampler, logdensity, init_params)
+function make_init_params(
+    rng::AbstractRNG,
+    spl::AbstractHMCSampler,
+    logdensity,
+    init_params,
+)
T = sampler_eltype(spl)
if init_params == nothing
d = LogDensityProblems.dimension(logdensity)
@@ -354,7 +359,9 @@ end
#########

function make_kernel(spl::NUTS, integrator::AbstractIntegrator)
-    return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
+    return HMCKernel(
+        Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn(spl.max_depth, spl.Δ_max)),
+    )
end

function make_kernel(spl::HMC, integrator::AbstractIntegrator)
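To make the src/abstractmcmc.jl change concrete: make_init_params now receives the caller's rng (so randomly drawn initial parameters respect it), and make_kernel for NUTS threads the sampler's own fields into the termination criterion. A rough equivalent of the new kernel construction; the Leapfrog step size below is illustrative only and not part of this diff:

using AdvancedHMC

spl = NUTS(0.8; max_depth = 20, Δ_max = 2000.0)
integrator = Leapfrog(0.1)  # step size chosen purely for illustration

# What make_kernel(spl, integrator) builds after this commit:
kernel = HMCKernel(
    Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn(spl.max_depth, spl.Δ_max)),
)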
54 changes: 52 additions & 2 deletions test/constructors.jl
@@ -1,9 +1,19 @@
using AdvancedHMC, AbstractMCMC, Random
include("common.jl")

get_kernel_hyperparams(spl::HMC, state) = state.κ.τ.termination_criterion.L
get_kernel_hyperparams(spl::HMCDA, state) = state.κ.τ.termination_criterion.λ
get_kernel_hyperparams(spl::NUTS, state) =
state.κ.τ.termination_criterion.max_depth, state.κ.τ.termination_criterion.Δ_max

get_kernel_hyperparamsT(spl::HMC, state) = typeof(state.κ.τ.termination_criterion.L)
get_kernel_hyperparamsT(spl::HMCDA, state) = typeof(state.κ.τ.termination_criterion.λ)
get_kernel_hyperparamsT(spl::NUTS, state) = typeof(state.κ.τ.termination_criterion.Δ_max)

@testset "Constructors" begin
d = 2
θ_init = randn(d)
rng = Random.default_rng()
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)

@testset "$T" for T in [Float32, Float64]
@@ -14,6 +24,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
@@ -22,6 +33,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
@@ -30,6 +42,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
@@ -38,6 +51,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = UnitEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
@@ -46,6 +60,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = DenseEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
@@ -54,6 +69,7 @@ include("common.jl")
adaptor_type = NesterovDualAveraging,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = one(T),
),
),
# This should perform the correct promotion for the 2nd argument.
@@ -63,14 +79,16 @@ include("common.jl")
adaptor_type = NesterovDualAveraging,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = one(T),
),
),
(
-NUTS(T(0.8)),
+NUTS(T(0.8); max_depth = 20, Δ_max = T(2000.0)),
(
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = (20, T(2000.0)),
),
),
(
@@ -79,6 +97,7 @@ include("common.jl")
adaptor_type = StanHMCAdaptor,
metric_type = UnitEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = (10, T(1000.0)),
),
),
(
@@ -87,6 +106,7 @@ include("common.jl")
adaptor_type = StanHMCAdaptor,
metric_type = DenseEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = (10, T(1000.0)),
),
),
(
@@ -95,6 +115,7 @@ include("common.jl")
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = JitteredLeapfrog{T,T},
kernel_hp = (10, T(1000.0)),
),
),
(
@@ -103,14 +124,14 @@ include("common.jl")
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = TemperedLeapfrog{T,T},
kernel_hp = (10, T(1000.0)),
),
),
]
# Make sure the sampler element type is preserved.
@test AdvancedHMC.sampler_eltype(sampler) == T

# Step.
-rng = Random.default_rng()
transition, state =
AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init)

@@ -126,6 +147,35 @@ include("common.jl")
@test AdvancedHMC.getmetric(state) isa expected.metric_type
@test AdvancedHMC.getintegrator(state) isa expected.integrator_type
@test AdvancedHMC.getadaptor(state) isa expected.adaptor_type

# Verify that the kernel is receiving the hyperparameters
@test get_kernel_hyperparams(sampler, state) == expected.kernel_hp
if typeof(sampler) <: HMC
@test get_kernel_hyperparamsT(sampler, state) == Int64
else
@test get_kernel_hyperparamsT(sampler, state) == T
end
end
end
end

@testset "Utils" begin
@testset "init_params" begin
d = 2
θ_init = randn(d)
rng = Random.default_rng()
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
logdensity = model.logdensity
spl = NUTS(0.8)
T = AdvancedHMC.sampler_eltype(spl)

metric = make_metric(spl, logdensity)
hamiltonian = Hamiltonian(metric, model)

init_params1 = make_init_params(rng, spl, logdensity, nothing)
@test typeof(init_params1) == Vector{T}
@test length(init_params1) == d
init_params2 = make_init_params(rng, spl, logdensity, θ_init)
@test init_params2 === θ_init
end
end
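As the helpers at the top of this test file show, the kernel hyperparameters end up on the sampler state at state.κ.τ.termination_criterion (field L for HMC, λ for HMCDA, and max_depth/Δ_max for NUTS). A quick check outside the test suite, reusing sampler and state from the sketch after the changed-files summary above:

tc = state.κ.τ.termination_criterion    # a GeneralisedNoUTurn for NUTS
@assert (tc.max_depth, tc.Δ_max) == (20, 2000.0)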

2 comments on commit 3a4b384

@yebai (Member) commented on 3a4b384, Jul 28, 2023


@JuliaRegistrator commented:

Registration pull request created: JuliaRegistries/General/88575

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.2 -m "<description of version>" 3a4b384b656a1eb7bff4f017f1f0cbf6fc5beca7
git push origin v0.5.2
