diff --git a/Project.toml b/Project.toml index 89250b2c..cbbf6ef2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,8 +1,9 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.2.28" +version = "0.3.0" [deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34" @@ -17,6 +18,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] +AbstractMCMC = "3" ArgCheck = "1, 2" DocStringExtensions = "0.8" InplaceOps = "0.3" diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 9c90bff0..b7814c66 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -5,6 +5,7 @@ const DEBUG = convert(Bool, parse(Int, get(ENV, "DEBUG_AHMC", "0"))) using Statistics: mean, var, middle using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling using StatsFuns: logaddexp, logsumexp +import Random using Random: GLOBAL_RNG, AbstractRNG using ProgressMeter: ProgressMeter using UnPack: @unpack @@ -16,6 +17,8 @@ using ArgCheck: @argcheck using DocStringExtensions +import AbstractMCMC + import StatsBase: sample include("utilities.jl") @@ -128,6 +131,9 @@ include("diagnosis.jl") include("sampler.jl") export sample +include("abstractmcmc.jl") +export DifferentiableDensityModel + include("contrib/ad.jl") ### Init diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl new file mode 100644 index 00000000..eb50cd8e --- /dev/null +++ b/src/abstractmcmc.jl @@ -0,0 +1,293 @@ +""" + HMCSampler + +A `AbstractMCMC.AbstractSampler` for kernels in AdvancedHMC.jl. + +# Fields + +$(FIELDS) + +# Notes + +Note that all the fields have the prefix `initial_` to indicate +that these will not necessarily correspond to the `kernel`, `metric`, +and `adaptor` after sampling. + +To access the updated fields use the resulting [`HMCState`](@ref). +""" +struct HMCSampler{K, M, A} <: AbstractMCMC.AbstractSampler + "Initial [`AbstractMCMCKernel`](@ref)." + initial_kernel::K + "Initial [`AbstractMetric`](@ref)." + initial_metric::M + "Initial [`AbstractAdaptor`](@ref)." + initial_adaptor::A +end +HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation()) + +""" + DifferentiableDensityModel(ℓπ, ∂ℓπ∂θ) + DifferentiableDensityModel(ℓπ, m::Module) + +A `AbstractMCMC.AbstractMCMCModel` representing a differentiable log-density. + +If a module `m` is given as the second argument, then `m` is assumed to be an +automatic-differentiation package and this will be used to compute the gradients. + +Note that the module `m` must be imported before usage, e.g. +```julia +using Zygote: Zygote +model = DifferentiableDensityModel(ℓπ, Zygote) +``` +results in a `model` which will use Zygote.jl as its AD-backend. + +# Fields +$(FIELDS) +""" +struct DifferentiableDensityModel{Tlogπ, T∂logπ∂θ} <: AbstractMCMC.AbstractModel + "Log-density. Maps `AbstractArray` to value of the log-density." + ℓπ::Tlogπ + "Gradient of log-density. Returns a tuple of `ℓπ` and the gradient evaluated at the given point." + ∂ℓπ∂θ::T∂logπ∂θ +end + +struct DummyMetric <: AbstractMetric end +function DifferentiableDensityModel(ℓπ, m::Module) + h = Hamiltonian(DummyMetric(), ℓπ, m) + return DifferentiableDensityModel(h.ℓπ, h.∂ℓπ∂θ) +end + +""" + HMCState + +Represents the state of a [`HMCSampler`](@ref). + +# Fields + +$(FIELDS) + +""" +struct HMCState{ + TTrans<:Transition, + TMetric<:AbstractMetric, + TKernel<:AbstractMCMCKernel, + TAdapt<:Adaptation.AbstractAdaptor +} + "Index of current iteration." + i::Int + "Current [`Transition`](@ref)." + transition::TTrans + "Current [`AbstractMetric`](@ref), possibly adapted." + metric::TMetric + "Current [`AbstractMCMCKernel`](@ref)." + κ::TKernel + "Current [`AbstractAdaptor`](@ref)." + adaptor::TAdapt +end + +""" + $(TYPEDSIGNATURES) + +A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction of [`HMCSampler`](@ref). +""" +function AbstractMCMC.sample( + model::DifferentiableDensityModel, + kernel::AbstractMCMCKernel, + metric::AbstractMetric, + adaptor::AbstractAdaptor, + N::Integer; + kwargs... +) + return AbstractMCMC.sample(Random.GLOBAL_RNG, model, kernel, metric, adaptor, N; kwargs...) +end + +function AbstractMCMC.sample( + rng::Random.AbstractRNG, + model::DifferentiableDensityModel, + kernel::AbstractMCMCKernel, + metric::AbstractMetric, + adaptor::AbstractAdaptor, + N::Integer; + progress = true, + verbose = false, + callback = nothing, + kwargs... +) + sampler = HMCSampler(kernel, metric, adaptor) + if callback === nothing + callback = HMCProgressCallback(N, progress = progress, verbose = verbose) + progress = false # don't use AMCMC's progress-funtionality + end + + return AbstractMCMC.mcmcsample( + rng, model, sampler, N; + progress = progress, + verbose = verbose, + callback = callback, + kwargs... + ) +end + +function AbstractMCMC.sample( + model::DifferentiableDensityModel, + kernel::AbstractMCMCKernel, + metric::AbstractMetric, + adaptor::AbstractAdaptor, + parallel::AbstractMCMC.AbstractMCMCParallel, + N::Integer, + nchains::Integer; + kwargs... +) + return AbstractMCMC.sample( + Random.GLOBAL_RNG, model, kernel, metric, adaptor, N, nchains; + kwargs... + ) +end + +function AbstractMCMC.sample( + rng::Random.AbstractRNG, + model::DifferentiableDensityModel, + kernel::AbstractMCMCKernel, + metric::AbstractMetric, + adaptor::AbstractAdaptor, + parallel::AbstractMCMC.AbstractMCMCParallel, + N::Integer, + nchains::Integer; + progress = true, + verbose = false, + callback = nothing, + kwargs... +) + sampler = HMCSampler(kernel, metric, adaptor) + if callback === nothing + callback = HMCProgressCallback(N, progress = progress, verbose = verbose) + progress = false # don't use AMCMC's progress-funtionality + end + + return AbstractMCMC.mcmcsample( + rng, model, sampler, parallel, N, nchains; + progress = progress, + verbose = verbose, + callback = callback, + kwargs... + ) +end + +function AbstractMCMC.step( + rng::AbstractRNG, + model::DifferentiableDensityModel, + spl::HMCSampler; + init_params = nothing, + kwargs... +) + metric = spl.initial_metric + κ = spl.initial_kernel + adaptor = spl.initial_adaptor + + if init_params === nothing + init_params = randn(size(metric, 1)) + end + + # Construct the hamiltonian using the initial metric + hamiltonian = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ) + + # Get an initial sample. + h, t = AdvancedHMC.sample_init(rng, hamiltonian, init_params) + + # Compute next transition and state. + state = HMCState(0, t, h.metric, κ, adaptor) + + # Take actual first step. + return AbstractMCMC.step(rng, model, spl, state; kwargs...) +end + +function AbstractMCMC.step( + rng::AbstractRNG, + model::DifferentiableDensityModel, + spl::HMCSampler, + state::HMCState; + nadapts::Int = 0, + kwargs... +) + # Get step size + @debug "current ϵ" getstepsize(spl, state) + + # Compute transition. + i = state.i + 1 + t_old = state.transition + adaptor = state.adaptor + κ = state.κ + metric = state.metric + + # Reconstruct hamiltonian. + h = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ) + + # Make new transition. + t = transition(rng, h, κ, t_old.z) + + # Adapt h and spl. + tstat = stat(t) + h, κ, isadapted = adapt!(h, κ, adaptor, i, nadapts, t.z.θ, tstat.acceptance_rate) + tstat = merge(tstat, (is_adapt=isadapted,)) + + # Compute next transition and state. + newstate = HMCState(i, t, h.metric, κ, adaptor) + + # Return `Transition` with additional stats added. + return Transition(t.z, tstat), newstate +end + + +################ +### Callback ### +################ +""" + HMCProgressCallback + +A callback to be used with AbstractMCMC.jl's interface, replicating the +logging behavior of the non-AbstractMCMC [`sample`](@ref). + +# Fields +$(FIELDS) +""" +struct HMCProgressCallback{P} + "`Progress` meter from ProgressMeters.jl." + pm::P + "Specifies whether or not to use display a progress bar." + progress::Bool + "If `progress` is not specified and this is `true` some information will be logged upon completion of adaptation." + verbose::Bool +end + +function HMCProgressCallback(n_samples; progress=true, verbose=false) + pm = progress ? ProgressMeter.Progress(n_samples, desc="Sampling", barlen=31) : nothing + HMCProgressCallback(pm, progress, verbose) +end + +function (cb::HMCProgressCallback)( + rng, model, spl, t, state, i; + nadapts = 0, + kwargs... +) + progress = cb.progress + verbose = cb.verbose + pm = cb.pm + + metric = state.metric + adaptor = state.adaptor + κ = state.κ + tstat = t.stat + isadapted = tstat.is_adapt + + # Update progress meter + if progress + # Do include current iteration and mass matrix + pm_next!( + pm, + (iterations=i, tstat..., mass_matrix=metric) + ) + # Report finish of adapation + elseif verbose && isadapted && i == nadapts + @info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric + end +end diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index a2bd55de..497df225 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -47,11 +47,11 @@ struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat}, V<:DualValue} @warn "The current proposal will be rejected due to numerical error(s)." isfinite.((θ, r, ℓπ, ℓκ)) # NOTE eltype has to be inlined to avoid type stability issue; see #267 ℓπ = DualValue( - map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓπ.value), + map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓπ.value), ℓπ.gradient ) ℓκ = DualValue( - map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓκ.value), + map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓκ.value), ℓκ.gradient ) end diff --git a/src/sampler.jl b/src/sampler.jl index a2de7515..39e8e8eb 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -33,7 +33,6 @@ end ## ## Interface functions ## - function sample_init( rng::Union{AbstractRNG, AbstractVector{<:AbstractRNG}}, h::Hamiltonian, @@ -143,7 +142,6 @@ sample( verbose::Bool=true, progress::Bool=false ) - Sample `n_samples` samples using the proposal `κ` under Hamiltonian `h`. - The randomness is controlled by `rng`. - If `rng` is not provided, `GLOBAL_RNG` will be used. diff --git a/src/trajectory.jl b/src/trajectory.jl index 2d748204..f943f591 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -213,7 +213,7 @@ nsteps(τ::Trajectory{TS, I, TC}) where {TS, I, TC<:FixedIntegrationTime} = ## Kernel interface ## -struct HMCKernel{R, T<:Trajectory} <: AbstractMCMCKernel +struct HMCKernel{R, T<:Trajectory} <: AbstractMCMCKernel refreshment::R τ::T end diff --git a/test/Project.toml b/test/Project.toml index 8927e4ba..20d589ee 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl new file mode 100644 index 00000000..61ae4f0e --- /dev/null +++ b/test/abstractmcmc.jl @@ -0,0 +1,38 @@ +using Test, Random, AdvancedHMC, ForwardDiff, AbstractMCMC +using Statistics: mean +include("common.jl") + +@testset "`gdemo`" begin + rng = MersenneTwister(0) + + n_samples = 5_000 + n_adapts = 5_000 + + θ_init = randn(rng, 2) + + model = AdvancedHMC.DifferentiableDensityModel(ℓπ_gdemo, ForwardDiff) + init_eps = Leapfrog(1e-3) + κ = NUTS(init_eps) + metric = DiagEuclideanMetric(2) + adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, κ.τ.integrator)) + + samples = AbstractMCMC.sample( + model, κ, metric, adaptor, n_adapts + n_samples; + nadapts = n_adapts, + init_params = θ_init, + progress=false, + verbose=false + ); + + # Transform back to original space. + # NOTE: We're not correcting for the `logabsdetjac` here since, but + # we're only interested in the mean it doesn't matter. + for t in samples + t.z.θ .= invlink_gdemo(t.z.θ) + end + m_est = mean(samples[n_adapts + 1:end]) do t + t.z.θ + end + + @test m_est ≈ [49 / 24, 7 / 6] atol=RNDATOL +end diff --git a/test/runtests.jl b/test/runtests.jl index 873b682b..2d016471 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,6 +21,7 @@ const GROUP = get(ENV, "AHMC_TEST_GROUP", "All") "sampler-vec", "demo", "models", + "abstractmcmc" ] if CUDA.functional()