Skip to content

Commit

Permalink
More bugfixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai committed Jul 26, 2023
1 parent 4e90876 commit 0857774
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
6 changes: 2 additions & 4 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ end

getadaptor(state::HMCState) = state.adaptor
getmetric(state::HMCState) = state.metric

getintegrator(state::HMCState) = state.κ.τ.integrator
getintegrator(state::HMCState) = state.κ.τ.integrator

"""
Expand Down Expand Up @@ -319,8 +317,8 @@ make_integrator(i::AbstractIntegrator, ϵ::Real) = i
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ)
make_integrator(@nospecialize(i), ::Real) = error("Integrator $i not supported.")
make_integrator(i::Val{:leapfrog}, ϵ::Real) = Leapfrog(ϵ)
make_integrator(i::Val{:jitteredleapfrog}, ϵ::Real) = JitteredLeapfrog(ϵ, 0.1ϵ)
make_integrator(i::Val{:temperedleapfrog}, ϵ::Real) = TemperedLeapfrog(ϵ, 1.0)
make_integrator(i::Val{:jitteredleapfrog}, ϵ::T) where T<:Real = JitteredLeapfrog(ϵ, T(0.1ϵ))
make_integrator(i::Val{:temperedleapfrog}, ϵ::T) where T<:Real = TemperedLeapfrog(ϵ, T(1))

#########

Expand Down
2 changes: 1 addition & 1 deletion src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ function _jitter(
lf::JitteredLeapfrog{FT,T},
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}}
ϵ = lf.ϵ0 .* (1 .+ lf.jitter .* (2 .* rand(rng) .- 1))
return @set lf.ϵ = ϵ
return @set lf.ϵ = FT.(ϵ)
end

jitter(rng::AbstractRNG, lf::JitteredLeapfrog) = _jitter(rng, lf)
Expand Down
4 changes: 2 additions & 2 deletions test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ include("common.jl")
(
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = JitterdLeapfrog{T},
integrator_type = JitteredLeapfrog{T, T},
),
),
(
NUTS(T(0.8); integrator = :temperedleapfrog),
(
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = TemperedLeapfrog{T},
integrator_type = TemperedLeapfrog{T, T},
),
),
]
Expand Down

0 comments on commit 0857774

Please sign in to comment.