Skip to content

Commit

Permalink
Generally support container of proposals (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored May 28, 2021
1 parent 941c046 commit 1911b9d
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 147 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.6.0"
version = "0.6.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
68 changes: 40 additions & 28 deletions src/MALA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,48 +19,60 @@ struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{
gradient::G
end

transition(::MALA, model, params) = GradientTransition(model, params)

# Store the new draw, its log density and its gradient
GradientTransition(model::DensityModel, params) = GradientTransition(params, logdensity_and_gradient(model, params)...)
logdensity(model::DensityModel, t::GradientTransition) = t.lp

propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
function transition(sampler::MALA, model::DensityModel, params)
return GradientTransition(params, logdensity_and_gradient(model, params)...)
end

function propose(
function AbstractMCMC.step(
rng::Random.AbstractRNG,
spl::MALA{<:Proposal},
model::DensityModel,
params_prev::GradientTransition
)
proposal = propose(rng, spl.proposal(params_prev.gradient), model, params_prev.params)
return GradientTransition(model, proposal)
end


function q(
spl::MALA{<:Proposal},
t::GradientTransition,
t_cond::GradientTransition
)
return q(spl.proposal(-t_cond.gradient), t.params, t_cond.params)
end

function logratio_proposal_density(
sampler::MALA{<:Proposal}, state::GradientTransition, candidate::GradientTransition
sampler::MALA,
transition_prev::GradientTransition;
kwargs...
)
return q(sampler, state, candidate) - q(sampler, candidate, state)
# Extract value and gradient of the log density of the current state.
state = transition_prev.params
logdensity_state = transition_prev.lp
gradient_logdensity_state = transition_prev.gradient

# Generate a new proposal.
proposal = sampler.proposal
candidate = propose(rng, proposal(gradient_logdensity_state), model, state)

# Compute both the value of the log density and its gradient
logdensity_candidate, gradient_logdensity_candidate = logdensity_and_gradient(
model, candidate
)

# Compute the log ratio of proposal densities.
logratio_proposal_density = q(
proposal(-gradient_logdensity_candidate), state, candidate
) - q(proposal(-gradient_logdensity_state), candidate, state)

# Compute the log acceptance probability.
logα = logdensity_candidate - logdensity_state + logratio_proposal_density

# Decide whether to return the previous params or the new one.
transition = if -Random.randexp(rng) < logα
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate)
else
transition_prev
end

return transition, transition
end

"""
logdensity_and_gradient(model::DensityModel, params)
Efficiently returns the value and gradient of the model
Return the value and gradient of the log density of the parameters `params` for the `model`.
"""
function logdensity_and_gradient(model::DensityModel, params)
res = GradientResult(params)
gradient!(res, model.logdensity, params)
return (value(res), gradient(res))
return value(res), gradient(res)
end


logdensity(model::DensityModel, t::GradientTransition) = t.lp
19 changes: 2 additions & 17 deletions src/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,8 @@ struct Ensemble{D} <: MHSampler
proposal::D
end

# Define the first sampling step.
# Return a 2-tuple consisting of the initial sample and the initial state.
# In this case they are identical.
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
spl::Ensemble;
init_params = nothing,
kwargs...,
)
if init_params === nothing
transitions = propose(rng, spl, model)
else
transitions = [Transition(model, x) for x in init_params]
end

return transitions, transitions
function transition(sampler::Ensemble, model::DensityModel, params)
return [Transition(model, x) for x in params]
end

# Define the other sampling steps.
Expand Down
128 changes: 27 additions & 101 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,115 +48,38 @@ end
StaticMH(d) = MetropolisHastings(StaticProposal(d))
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))

# default function without RNG
propose(spl::MetropolisHastings, args...) = propose(Random.GLOBAL_RNG, spl, args...)

# Propose from a vector of proposals
function propose(
rng::Random.AbstractRNG,
spl::MetropolisHastings{<:AbstractArray},
model::DensityModel
)
proposal = map(p -> propose(rng, p, model), spl.proposal)
return Transition(model, proposal)
end

function propose(
rng::Random.AbstractRNG,
spl::MetropolisHastings{<:AbstractArray},
model::DensityModel,
params_prev::Transition
)
proposal = map(spl.proposal, params_prev.params) do p, params
propose(rng, p, model, params)
end
return Transition(model, proposal)
end

# Make a proposal from one Proposal struct.
function propose(
rng::Random.AbstractRNG,
spl::MetropolisHastings{<:Proposal},
model::DensityModel
)
proposal = propose(rng, spl.proposal, model)
return Transition(model, proposal)
end

function propose(
rng::Random.AbstractRNG,
spl::MetropolisHastings{<:Proposal},
model::DensityModel,
params_prev::Transition
)
proposal = propose(rng, spl.proposal, model, params_prev.params)
return Transition(model, proposal)
end

# Make a proposal from a NamedTuple of Proposal.
function propose(
rng::Random.AbstractRNG,
spl::MetropolisHastings{<:NamedTuple},
model::DensityModel
)
proposal = _propose(rng, spl.proposal, model)
return Transition(model, proposal)
function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModel)
return propose(rng, sampler.proposal, model)
end

function propose(
rng::Random.AbstractRNG,
spl::MetropolisHastings{<:NamedTuple},
sampler::MHSampler,
model::DensityModel,
params_prev::Transition
transition_prev::Transition,
)
proposal = _propose(rng, spl.proposal, model, params_prev.params)
return Transition(model, proposal)
return propose(rng, sampler.proposal, model, transition_prev.params)
end

@generated function _propose(
rng::Random.AbstractRNG,
proposal::NamedTuple{names},
model::DensityModel
) where {names}
isempty(names) && return :(NamedTuple())
expr = Expr(:tuple)
expr.args = Any[:($name = propose(rng, proposal.$name, model)) for name in names]
return expr
function transition(sampler::MHSampler, model::DensityModel, params)
logdensity = AdvancedMH.logdensity(model, params)
return transition(sampler, model, params, logdensity)
end

@generated function _propose(
rng::Random.AbstractRNG,
proposal::NamedTuple{names},
model::DensityModel,
params_prev::NamedTuple
) where {names}
isempty(names) && return :(NamedTuple())
expr = Expr(:tuple)
expr.args = Any[
:($name = propose(rng, proposal.$name, model, params_prev.$name)) for name in names
]
return expr
function transition(sampler::MHSampler, model::DensityModel, params, logdensity::Real)
return Transition(params, logdensity)
end

transition(sampler, model, params) = transition(model, params)
transition(model, params) = Transition(model, params)

# Define the first sampling step.
# Return a 2-tuple consisting of the initial sample and the initial state.
# In this case they are identical.
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
spl::MHSampler;
sampler::MHSampler;
init_params=nothing,
kwargs...
)
if init_params === nothing
transition = propose(rng, spl, model)
else
transition = AdvancedMH.transition(spl, model, init_params)
end

params = init_params === nothing ? propose(rng, sampler, model) : init_params
transition = AdvancedMH.transition(sampler, model, params)
return transition, transition
end

Expand All @@ -167,27 +90,30 @@ end
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
spl::MHSampler,
params_prev::AbstractTransition;
sampler::MHSampler,
transition_prev::AbstractTransition;
kwargs...
)
# Generate a new proposal.
params = propose(rng, spl, model, params_prev)
candidate = propose(rng, sampler, model, transition_prev)

# Calculate the log acceptance probability.
logα = logdensity(model, params) - logdensity(model, params_prev) +
logratio_proposal_density(spl, params_prev, params)
# Calculate the log acceptance probability and the log density of the candidate.
logdensity_candidate = logdensity(model, candidate)
logα = logdensity_candidate - logdensity(model, transition_prev) +
logratio_proposal_density(sampler, transition_prev, candidate)

# Decide whether to return the previous params or the new one.
if -Random.randexp(rng) < logα
return params, params
transition = if -Random.randexp(rng) < logα
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate)
else
return params_prev, params_prev
transition_prev
end

return transition, transition
end

function logratio_proposal_density(
sampler::MetropolisHastings, params_prev::Transition, params::Transition
sampler::MetropolisHastings, transition_prev::AbstractTransition, candidate
)
return logratio_proposal_density(sampler.proposal, params_prev.params, params.params)
return logratio_proposal_density(sampler.proposal, transition_prev.params, candidate)
end
49 changes: 49 additions & 0 deletions src/proposal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,55 @@ function q(
return q(proposal(t_cond), t, t_cond)
end

####################
# Multiple proposals
####################

function propose(
rng::Random.AbstractRNG,
proposals::AbstractArray{<:Proposal},
model::DensityModel,
)
return map(proposals) do proposal
return propose(rng, proposal, model)
end
end
function propose(
rng::Random.AbstractRNG,
proposals::AbstractArray{<:Proposal},
model::DensityModel,
ts,
)
return map(proposals, ts) do proposal, t
return propose(rng, proposal, model, t)
end
end

@generated function propose(
rng::Random.AbstractRNG,
proposals::NamedTuple{names},
model::DensityModel,
) where {names}
isempty(names) && return :(NamedTuple())
expr = Expr(:tuple)
expr.args = Any[:($name = propose(rng, proposals.$name, model)) for name in names]
return expr
end

@generated function propose(
rng::Random.AbstractRNG,
proposals::NamedTuple{names},
model::DensityModel,
ts,
) where {names}
isempty(names) && return :(NamedTuple())
expr = Expr(:tuple)
expr.args = Any[
:($name = propose(rng, proposals.$name, model, ts.$name)) for name in names
]
return expr
end

"""
logratio_proposal_density(proposal, state, candidate)
Expand Down

2 comments on commit 1911b9d

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37741

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.1 -m "<description of version>" 1911b9d2d7116d974d04b31d5a530d2ddbbc9dcf
git push origin v0.6.1

Please sign in to comment.