From b002f8567ba910b392031ba578b6dd3359cb8fc6 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 26 Jul 2023 11:42:33 +0100 Subject: [PATCH] More fixes. --- src/constructors.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/constructors.jl b/src/constructors.jl index ccde73f9..35367963 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -28,20 +28,18 @@ struct HMCSampler{T<:Real} <: AbstractHMCSampler{T} metric::AbstractMetric "[`AbstractAdaptor`](@ref)." adaptor::AbstractAdaptor - "Adaptation steps if any" - n_adapts::Int end -function HMCSampler(κ, metric, adaptor; n_adapts = 0) +function HMCSampler(κ, metric, adaptor) T = collect(typeof(metric).parameters)[1] - return HMCSampler{T}(κ, metric, adaptor, n_adapts) + return HMCSampler{T}(κ, metric, adaptor) end ############ ### NUTS ### ############ """ - NUTS(n_adapts::Int, δ::Real; max_depth::Int=10, Δ_max::Real=1000, init_ϵ::Real=0) + NUTS(δ::Real; max_depth::Int=10, Δ_max::Real=1000, init_ϵ::Real=0, init_ϵ = 0.0, integrator = :leapfrog, metric = :diagonal) No-U-Turn Sampler (NUTS) sampler. @@ -52,7 +50,7 @@ $(FIELDS) # Usage: ```julia -NUTS(n_adapts=1000, δ=0.65) # Use 1000 adaption steps, and target accept ratio 0.65. +NUTS(δ=0.65) # Use target accept ratio 0.65. ``` """ struct NUTS{T<:Real} <: AbstractHMCSampler{T} @@ -97,7 +95,7 @@ $(FIELDS) # Usage: ```julia -HMC(init_ϵ=0.05, n_leapfrog=10) +HMC(init_ϵ=0.05, n_leapfrog=10, integrator = :leapfrog, metric = :diagonal) ``` """ struct HMC{T<:Real} <: AbstractHMCSampler{T} @@ -119,7 +117,7 @@ end ### HMCDA ### ############# """ - HMCDA(n_adapts::Int, δ::Real, λ::Real; ϵ::Real=0) + HMCDA(δ::Real, λ::Real; ϵ::Real=0, integrator = :leapfrog, metric = :diagonal) Hamiltonian Monte Carlo sampler with Dual Averaging algorithm. @@ -130,7 +128,7 @@ $(FIELDS) # Usage: ```julia -HMCDA(n_adapts=200, δ=0.65, λ=0.3) +HMCDA(δ=0.65, λ=0.3) ``` For more information, please view the following paper ([arXiv link](https://arxiv.org/abs/1111.4246)):