Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReadMe for consctructors #329

Merged
merged 33 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
e7489a6
bug + docs
JaimeRZP Jul 21, 2023
bca300c
read me
JaimeRZP Jul 21, 2023
1c698a3
format
JaimeRZP Jul 21, 2023
4844619
Update abstractmcmc.jl
yebai Jul 21, 2023
108f98c
Apply suggestions from code review
JaimeRZP Jul 21, 2023
75333d5
API
JaimeRZP Jul 21, 2023
cdc8645
Merge branch 'ReadMe_constructors' of https://github.com/TuringLang/A…
JaimeRZP Jul 21, 2023
a7055cd
Apply suggestions from code review
JaimeRZP Jul 25, 2023
7085642
Apply suggestions from code review
JaimeRZP Jul 25, 2023
8e62dc9
Hong s comments
JaimeRZP Jul 25, 2023
306ad4d
Fix typo and simplify arguments (#331)
yebai Jul 25, 2023
021de4f
Removed `n_adapts` from sampler constructors and some fixes. (#333)
yebai Jul 26, 2023
69fb3ee
Merge branch 'master' into ReadMe_constructors
JaimeRZP Jul 26, 2023
f43ac6b
Merge branch 'master' into ReadMe_constructors
JaimeRZP Jul 26, 2023
c207e81
Minor tweaks to the metric field comments.
yebai Jul 26, 2023
4e9ed26
Removed redundant make_metric function
yebai Jul 26, 2023
0ef0ccf
Fix typos in constructor tests
yebai Jul 26, 2023
b002f85
More fixes.
yebai Jul 26, 2023
cc0ae86
Typofix.
yebai Jul 26, 2023
a5c4983
More test fixes.
yebai Jul 26, 2023
032cf35
Update test/constructors.jl
yebai Jul 26, 2023
3f7fbcb
no init_e (#335)
JaimeRZP Jul 26, 2023
1981781
Merge branch 'master' into ReadMe_constructors
JaimeRZP Jul 26, 2023
5e46591
format
JaimeRZP Jul 26, 2023
795fefb
docs update for init_e
JaimeRZP Jul 26, 2023
f5679ec
Update constructors.jl
yebai Jul 26, 2023
d0673a2
Update README.md
yebai Jul 26, 2023
4e90876
Update constructors.jl
yebai Jul 26, 2023
0857774
More bugfixes.
yebai Jul 26, 2023
f27f5a7
Apply suggestions from code review
yebai Jul 26, 2023
6cd1d40
Update src/constructors.jl
yebai Jul 26, 2023
70cd531
Update src/constructors.jl
yebai Jul 26, 2023
b05759d
Update README.md
yebai Jul 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 141 additions & 31 deletions README.md

Large diffs are not rendered by default.

47 changes: 31 additions & 16 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,47 +271,62 @@ end

#########

function make_step_size(
rng::Random.AbstractRNG,
spl::HMCSampler,
hamiltonian::Hamiltonian,
init_params,
)
return spl.κ.τ.integrator.ϵ
end

function make_step_size(
rng::Random.AbstractRNG,
spl::AbstractHMCSampler,
hamiltonian::Hamiltonian,
init_params,
)
ϵ = spl.init_ϵ
if iszero(ϵ)
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
T = sampler_eltype(spl)
ϵ = T(ϵ)
@info string("Found initial step size ", ϵ)
end
return ϵ
T = sampler_eltype(spl)
return make_step_size(rng, spl.integrator, T, hamiltonian, init_params)

end

function make_step_size(
rng::Random.AbstractRNG,
spl::HMCSampler,
integrator::AbstractIntegrator,
T::Type,
hamiltonian::Hamiltonian,
init_params,
)
return spl.κ.τ.integrator.ϵ
return integrator.ϵ
end

function make_step_size(
rng::Random.AbstractRNG,
integrator::Symbol,
T::Type,
hamiltonian::Hamiltonian,
init_params,
)
ϵ = find_good_stepsize(rng, hamiltonian, init_params)
@info string("Found initial step size ", ϵ)
return T(ϵ)
end

make_integrator(spl::HMCSampler, ϵ::Real) = spl.κ.τ.integrator
make_integrator(spl::AbstractHMCSampler, ϵ::Real) = make_integrator(spl.integrator, ϵ)
make_integrator(i::AbstractIntegrator, ϵ::Real) = i
make_integrator(i::Type{<:AbstractIntegrator}, ϵ::Real) = i
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ)
make_integrator(i...) = error("Integrator $(typeof(i)) not supported.")
make_integrator(@nospecialize(i), ::Real) = error("Integrator $i not supported.")
make_integrator(i::Val{:leapfrog}, ϵ::Real) = Leapfrog(ϵ)
make_integrator(i::Val{:jitteredleapfrog}, ϵ::Real) = JitteredLeapfrog(ϵ)
make_integrator(i::Val{:temperedleapfrog}, ϵ::Real) = TemperedLeapfrog(ϵ)
make_integrator(i::Val{:jitteredleapfrog}, ϵ::Real) = JitteredLeapfrog(ϵ, 0.1ϵ)
make_integrator(i::Val{:temperedleapfrog}, ϵ::Real) = TemperedLeapfrog(ϵ, 1.0)

#########

make_metric(i...) = error("Metric $(typeof(i)) not supported.")
make_metric(@nospecialize(i), T::Type, d::Int) = error("Metric $(typeof(i)) not supported.")
make_metric(i::Symbol, T::Type, d::Int) = make_metric(Val(i), T, d)
make_metric(i::AbstractMetric, T::Type, d::Int) = i
make_metric(i::Type{AbstractMetric}, T::Type, d::Int) = i
make_metric(i::Val{:diagonal}, T::Type, d::Int) = DiagEuclideanMetric(T, d)
make_metric(i::Val{:unit}, T::Type, d::Int) = UnitEuclideanMetric(T, d)
make_metric(i::Val{:dense}, T::Type, d::Int) = DenseEuclideanMetric(T, d)
Expand Down
50 changes: 20 additions & 30 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,18 @@ struct HMCSampler{T<:Real} <: AbstractHMCSampler{T}
metric::AbstractMetric
"[`AbstractAdaptor`](@ref)."
adaptor::AbstractAdaptor
"Adaptation steps if any"
n_adapts::Int
end

function HMCSampler(κ, metric, adaptor; n_adapts = 0)
function HMCSampler(κ, metric, adaptor)
T = collect(typeof(metric).parameters)[1]
return HMCSampler{T}(κ, metric, adaptor, n_adapts)
return HMCSampler{T}(κ, metric, adaptor)
end

############
### NUTS ###
############
"""
NUTS(n_adapts::Int, δ::Real; max_depth::Int=10, Δ_max::Real=1000, init_ϵ::Real=0)
NUTS(δ::Real; max_depth::Int=10, Δ_max::Real=1000, init_ϵ::Real=0, integrator = :leapfrog, metric = :diagonal)

No-U-Turn Sampler (NUTS) sampler.

Expand All @@ -52,7 +50,7 @@ $(FIELDS)
# Usage:

```julia
NUTS(n_adapts=1000, δ=0.65) # Use 1000 adaption steps, and target accept ratio 0.65.
NUTS(δ=0.65) # Use target accept ratio 0.65.
```
"""
struct NUTS{T<:Real} <: AbstractHMCSampler{T}
Expand All @@ -62,24 +60,15 @@ struct NUTS{T<:Real} <: AbstractHMCSampler{T}
max_depth::Int
"Maximum divergence during doubling tree."
Δ_max::T
"Initial step size; 0 means it is automatically chosen."
init_ϵ::T
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::Union{Symbol,AbstractIntegrator}
"Choice of initial metric, specified using a `Symbol` or `AbstractMetric`. The metric type will be preserved during adaption."
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::Union{Symbol,AbstractMetric}
end

function NUTS(
δ;
max_depth = 10,
Δ_max = 1000.0,
init_ϵ = 0.0,
integrator = :leapfrog,
metric = :diagonal,
)
function NUTS(δ; max_depth = 10, Δ_max = 1000.0, integrator = :leapfrog, metric = :diagonal)
T = typeof(δ)
return NUTS(δ, max_depth, T(Δ_max), T(init_ϵ), integrator, metric)
return NUTS(δ, max_depth, T(Δ_max), integrator, metric)
end

###########
Expand All @@ -97,29 +86,32 @@ $(FIELDS)
# Usage:

```julia
HMC(init_ϵ=0.05, n_leapfrog=10)
HMC(init_ϵ=0.05, n_leapfrog=10, integrator = :leapfrog, metric = :diagonal)
yebai marked this conversation as resolved.
Show resolved Hide resolved
```
"""
struct HMC{T<:Real} <: AbstractHMCSampler{T}
"Initial step size; 0 means automatically searching using a heuristic procedure."
init_ϵ::T
"Number of leapfrog steps."
n_leapfrog::Int
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::Union{Symbol,AbstractIntegrator}
"Choice of initial metric, specified using a `Symbol` or `AbstractMetric`. The metric type will be preserved during adaption."
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::Union{Symbol,AbstractMetric}
end

function HMC(init_ϵ, n_leapfrog; integrator = :leapfrog, metric = :diagonal)
return HMC(init_ϵ, n_leapfrog, integrator, metric)
function HMC(n_leapfrog; integrator = :leapfrog, metric = :diagonal)
if integrator isa Symbol
T = typeof(0.0) # current default float type
else
T = integrator_eltype(integrator)
end
return HMC{T}(n_leapfrog, integrator, metric)
end

#############
### HMCDA ###
#############
"""
HMCDA(n_adapts::Int, δ::Real, λ::Real; ϵ::Real=0)
HMCDA(δ::Real, λ::Real; ϵ::Real=0, integrator = :leapfrog, metric = :diagonal)

Hamiltonian Monte Carlo sampler with Dual Averaging algorithm.

Expand All @@ -130,7 +122,7 @@ $(FIELDS)
# Usage:

```julia
HMCDA(n_adapts=200, δ=0.65, λ=0.3)
HMCDA(δ=0.65, λ=0.3)
```

For more information, please view the following paper ([arXiv link](https://arxiv.org/abs/1111.4246)):
Expand All @@ -144,16 +136,14 @@ struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
δ::T
"Target leapfrog length."
λ::T
"Initial step size; 0 means automatically searching using a heuristic procedure."
init_ϵ::T
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::Union{Symbol,AbstractIntegrator}
"Choice of initial metric, specified using a `Symbol` or `AbstractMetric`. The metric type will be preserved during adaption."
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::Union{Symbol,AbstractMetric}
end

function HMCDA(δ, λ; init_ϵ = 0, integrator = :leapfrog, metric = :diagonal)
δ, λ = promote(δ, λ)
T = typeof(δ)
return HMCDA(δ, T(λ), T(init_ϵ), integrator, metric)
return HMCDA(δ, T(λ), integrator, metric)
end
1 change: 1 addition & 0 deletions src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct Leapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
ϵ::T
end
Base.show(io::IO, l::Leapfrog) = print(io, "Leapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)))")
integrator_eltype(i::AbstractLeapfrog{T}) where {T<:AbstractFloat} = T

### Jittering

Expand Down
2 changes: 1 addition & 1 deletion test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ include("common.jl")
θ_init = randn(rng, 2)

nuts = NUTS(0.8)
hmc = HMC(0.05, 100)
hmc = HMC(100; integrator = Leapfrog(0.05))
hmcda = HMCDA(0.8, 0.1)

integrator = Leapfrog(1e-3)
Expand Down
21 changes: 19 additions & 2 deletions test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include("common.jl")
@testset "$T" for T in [Float32, Float64]
@testset "$(nameof(typeof(sampler)))" for (sampler, expected) in [
(
HMC(T(0.1), 25),
HMC(25),
(
adaptor_type = NoAdaptation,
metric_type = DiagEuclideanMetric{T},
Expand Down Expand Up @@ -48,6 +48,23 @@ include("common.jl")
integrator_type = Leapfrog{T},
),
),
(
NUTS(T(0.8); integrator = :jitteredleapfrog),
(
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = JitterdLeapfrog{T},
),
),
(
NUTS(T(0.8); integrator = :temperedleapfrog),
JaimeRZP marked this conversation as resolved.
Show resolved Hide resolved
(
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = TemperedLeapfrog{T},
),
),

JaimeRZP marked this conversation as resolved.
Show resolved Hide resolved
]
# Make sure the sampler element type is preserved.
@test AdvancedHMC.sampler_eltype(sampler) == T
Expand All @@ -56,7 +73,7 @@ include("common.jl")
rng = Random.default_rng()
transition, state =
AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init)

JaimeRZP marked this conversation as resolved.
Show resolved Hide resolved
# Verify that the types are preserved in the transition.
@test eltype(transition.z.θ) == T
@test eltype(transition.z.r) == T
Expand Down