Skip to content

Commit

Permalink
ReadMe for consctructors (#329)
Browse files Browse the repository at this point in the history
* bug + docs

* read me

* format

* Update abstractmcmc.jl

* Apply suggestions from code review

Co-authored-by: Hong Ge <[email protected]>

* API

* Apply suggestions from code review

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Apply suggestions from code review

Co-authored-by: Hong Ge <[email protected]>

* Hong s comments

* Fix typo and simplify arguments (#331)

* Removed `n_adapts` from sampler constructors and some fixes.  (#333)

* Update README.md

* Update README.md

Co-authored-by: Tor Erlend Fjelde <[email protected]>

---------

Co-authored-by: Jaime RZ <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Minor tweaks to the metric field comments.

* Removed redundant make_metric function

* Fix typos in constructor tests

* More fixes.

* Typofix.

* More test fixes.

* Update test/constructors.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* no init_e (#335)

* no init_e

* format

* bug

* Bugfix. (#337)

* Bugfix.

* Update src/constructors.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/abstractmcmc.jl

* Update src/abstractmcmc.jl

Co-authored-by: Hong Ge <[email protected]>

---------

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* format

* docs update for init_e

* Update constructors.jl

* Update README.md

* Update constructors.jl

* More bugfixes.

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/constructors.jl

* Update src/constructors.jl

* Update README.md

---------

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
5 people authored Jul 26, 2023
1 parent cd5136f commit e302429
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 84 deletions.
171 changes: 140 additions & 31 deletions README.md

Large diffs are not rendered by default.

50 changes: 32 additions & 18 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ end

getadaptor(state::HMCState) = state.adaptor
getmetric(state::HMCState) = state.metric

getintegrator(state::HMCState) = state.κ.τ.integrator
getintegrator(state::HMCState) = state.κ.τ.integrator

"""
Expand Down Expand Up @@ -271,47 +269,63 @@ end

#########

function make_step_size(
rng::Random.AbstractRNG,
spl::HMCSampler,
hamiltonian::Hamiltonian,
init_params,
)
return spl.κ.τ.integrator.ϵ
end

function make_step_size(
rng::Random.AbstractRNG,
spl::AbstractHMCSampler,
hamiltonian::Hamiltonian,
init_params,
)
ϵ = spl.init_ϵ
if iszero(ϵ)
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
T = sampler_eltype(spl)
ϵ = T(ϵ)
@info string("Found initial step size ", ϵ)
end
return ϵ
T = sampler_eltype(spl)
return make_step_size(rng, spl.integrator, T, hamiltonian, init_params)

end

function make_step_size(
rng::Random.AbstractRNG,
spl::HMCSampler,
integrator::AbstractIntegrator,
T::Type,
hamiltonian::Hamiltonian,
init_params,
)
return spl.κ.τ.integrator.ϵ
return integrator.ϵ
end

function make_step_size(
rng::Random.AbstractRNG,
integrator::Symbol,
T::Type,
hamiltonian::Hamiltonian,
init_params,
)
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
@info string("Found initial step size ", ϵ)
return T(ϵ)
end

make_integrator(spl::HMCSampler, ϵ::Real) = spl.κ.τ.integrator
make_integrator(spl::AbstractHMCSampler, ϵ::Real) = make_integrator(spl.integrator, ϵ)
make_integrator(i::AbstractIntegrator, ϵ::Real) = i
make_integrator(i::Type{<:AbstractIntegrator}, ϵ::Real) = i
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ)
make_integrator(i...) = error("Integrator $(typeof(i)) not supported.")
make_integrator(@nospecialize(i), ::Real) = error("Integrator $i not supported.")
make_integrator(i::Val{:leapfrog}, ϵ::Real) = Leapfrog(ϵ)
make_integrator(i::Val{:jitteredleapfrog}, ϵ::Real) = JitteredLeapfrog(ϵ)
make_integrator(i::Val{:temperedleapfrog}, ϵ::Real) = TemperedLeapfrog(ϵ)
make_integrator(i::Val{:jitteredleapfrog}, ϵ::T) where {T<:Real} =
JitteredLeapfrog(ϵ, T(0.1ϵ))
make_integrator(i::Val{:temperedleapfrog}, ϵ::T) where {T<:Real} = TemperedLeapfrog(ϵ, T(1))

#########

make_metric(i...) = error("Metric $(typeof(i)) not supported.")
make_metric(@nospecialize(i), T::Type, d::Int) = error("Metric $(typeof(i)) not supported.")
make_metric(i::Symbol, T::Type, d::Int) = make_metric(Val(i), T, d)
make_metric(i::AbstractMetric, T::Type, d::Int) = i
make_metric(i::Type{AbstractMetric}, T::Type, d::Int) = i
make_metric(i::Val{:diagonal}, T::Type, d::Int) = DiagEuclideanMetric(T, d)
make_metric(i::Val{:unit}, T::Type, d::Int) = UnitEuclideanMetric(T, d)
make_metric(i::Val{:dense}, T::Type, d::Int) = DenseEuclideanMetric(T, d)
Expand Down
50 changes: 20 additions & 30 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,18 @@ struct HMCSampler{T<:Real} <: AbstractHMCSampler{T}
metric::AbstractMetric
"[`AbstractAdaptor`](@ref)."
adaptor::AbstractAdaptor
"Adaptation steps if any"
n_adapts::Int
end

function HMCSampler(κ, metric, adaptor; n_adapts = 0)
function HMCSampler(κ, metric, adaptor)
T = collect(typeof(metric).parameters)[1]
return HMCSampler{T}(κ, metric, adaptor, n_adapts)
return HMCSampler{T}(κ, metric, adaptor)
end

############
### NUTS ###
############
"""
NUTS(n_adapts::Int, δ::Real; max_depth::Int=10, Δ_max::Real=1000, init_ϵ::Real=0)
NUTS(δ::Real; max_depth::Int=10, Δ_max::Real=1000, init_ϵ::Real=0, integrator = :leapfrog, metric = :diagonal)
No-U-Turn Sampler (NUTS) sampler.
Expand All @@ -52,7 +50,7 @@ $(FIELDS)
# Usage:
```julia
NUTS(n_adapts=1000, δ=0.65) # Use 1000 adaption steps, and target accept ratio 0.65.
NUTS(δ=0.65) # Use target accept ratio 0.65.
```
"""
struct NUTS{T<:Real} <: AbstractHMCSampler{T}
Expand All @@ -62,24 +60,15 @@ struct NUTS{T<:Real} <: AbstractHMCSampler{T}
max_depth::Int
"Maximum divergence during doubling tree."
Δ_max::T
"Initial step size; 0 means it is automatically chosen."
init_ϵ::T
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::Union{Symbol,AbstractIntegrator}
"Choice of initial metric, specified using a `Symbol` or `AbstractMetric`. The metric type will be preserved during adaption."
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::Union{Symbol,AbstractMetric}
end

function NUTS(
δ;
max_depth = 10,
Δ_max = 1000.0,
init_ϵ = 0.0,
integrator = :leapfrog,
metric = :diagonal,
)
function NUTS(δ; max_depth = 10, Δ_max = 1000.0, integrator = :leapfrog, metric = :diagonal)
T = typeof(δ)
return NUTS(δ, max_depth, T(Δ_max), T(init_ϵ), integrator, metric)
return NUTS(δ, max_depth, T(Δ_max), integrator, metric)
end

###########
Expand All @@ -97,29 +86,32 @@ $(FIELDS)
# Usage:
```julia
HMC(init_ϵ=0.05, n_leapfrog=10)
HMC(10, integrator = Leapfrog(0.05), metric = :diagonal)
```
"""
struct HMC{T<:Real} <: AbstractHMCSampler{T}
"Initial step size; 0 means automatically searching using a heuristic procedure."
init_ϵ::T
"Number of leapfrog steps."
n_leapfrog::Int
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::Union{Symbol,AbstractIntegrator}
"Choice of initial metric, specified using a `Symbol` or `AbstractMetric`. The metric type will be preserved during adaption."
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::Union{Symbol,AbstractMetric}
end

function HMC(init_ϵ, n_leapfrog; integrator = :leapfrog, metric = :diagonal)
return HMC(init_ϵ, n_leapfrog, integrator, metric)
function HMC(n_leapfrog; integrator = :leapfrog, metric = :diagonal)
if integrator isa Symbol
T = typeof(0.0) # current default float type
else
T = integrator_eltype(integrator)
end
return HMC{T}(n_leapfrog, integrator, metric)
end

#############
### HMCDA ###
#############
"""
HMCDA(n_adapts::Int, δ::Real, λ::Real; ϵ::Real=0)
HMCDA(δ::Real, λ::Real; ϵ::Real=0, integrator = :leapfrog, metric = :diagonal)
Hamiltonian Monte Carlo sampler with Dual Averaging algorithm.
Expand All @@ -130,7 +122,7 @@ $(FIELDS)
# Usage:
```julia
HMCDA(n_adapts=200, δ=0.65, λ=0.3)
HMCDA(δ=0.65, λ=0.3)
```
For more information, please view the following paper ([arXiv link](https://arxiv.org/abs/1111.4246)):
Expand All @@ -144,16 +136,14 @@ struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
δ::T
"Target leapfrog length."
λ::T
"Initial step size; 0 means automatically searching using a heuristic procedure."
init_ϵ::T
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::Union{Symbol,AbstractIntegrator}
"Choice of initial metric, specified using a `Symbol` or `AbstractMetric`. The metric type will be preserved during adaption."
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::Union{Symbol,AbstractMetric}
end

function HMCDA(δ, λ; init_ϵ = 0, integrator = :leapfrog, metric = :diagonal)
δ, λ = promote(δ, λ)
T = typeof(δ)
return HMCDA(δ, T(λ), T(init_ϵ), integrator, metric)
return HMCDA(δ, T(λ), integrator, metric)
end
3 changes: 2 additions & 1 deletion src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct Leapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
ϵ::T
end
Base.show(io::IO, l::Leapfrog) = print(io, "Leapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)))")
integrator_eltype(i::AbstractLeapfrog{T}) where {T<:AbstractFloat} = T

### Jittering

Expand Down Expand Up @@ -131,7 +132,7 @@ function _jitter(
lf::JitteredLeapfrog{FT,T},
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}}
ϵ = lf.ϵ0 .* (1 .+ lf.jitter .* (2 .* rand(rng) .- 1))
return @set lf.ϵ = ϵ
return @set lf.ϵ = FT.(ϵ)
end

jitter(rng::AbstractRNG, lf::JitteredLeapfrog) = _jitter(rng, lf)
Expand Down
2 changes: 1 addition & 1 deletion test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ include("common.jl")
θ_init = randn(rng, 2)

nuts = NUTS(0.8)
hmc = HMC(0.05, 100)
hmc = HMC(100; integrator = Leapfrog(0.05))
hmcda = HMCDA(0.8, 0.1)

integrator = Leapfrog(1e-3)
Expand Down
22 changes: 19 additions & 3 deletions test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ include("common.jl")
@testset "$T" for T in [Float32, Float64]
@testset "$(nameof(typeof(sampler)))" for (sampler, expected) in [
(
HMC(T(0.1), 25),
HMC(25, integrator = Leapfrog(T(0.1))),
(
adaptor_type = NoAdaptation,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
),
),
# This should peform the correct promotion for the 2nd argument.
# This should perform the correct promotion for the 2nd argument.
(
HMCDA(T(0.1), 1),
HMCDA(T(0.8), 1, integrator = Leapfrog(T(0.1))),
(
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
Expand Down Expand Up @@ -48,6 +48,22 @@ include("common.jl")
integrator_type = Leapfrog{T},
),
),
(
NUTS(T(0.8); integrator = :jitteredleapfrog),
(
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = JitteredLeapfrog{T,T},
),
),
(
NUTS(T(0.8); integrator = :temperedleapfrog),
(
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = TemperedLeapfrog{T,T},
),
),
]
# Make sure the sampler element type is preserved.
@test AdvancedHMC.sampler_eltype(sampler) == T
Expand Down

0 comments on commit e302429

Please sign in to comment.