Skip to content

Commit

Permalink
Switch design to immutable samplers (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jul 2, 2020
1 parent 509b5cf commit f03a17d
Show file tree
Hide file tree
Showing 11 changed files with 416 additions and 468 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "1.0.1"
version = "2.0.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
149 changes: 55 additions & 94 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,154 +6,115 @@ Concatenate multiple chains.
chainscat(c::AbstractChains...) = cat(c...; dims=3)

"""
sample_init!(rng, model, sampler, N[; kwargs...])
bundle_samples(samples, model, sampler, state, chain_type[; kwargs...])
Perform the initial setup of the MCMC `sampler` for the provided `model`.
Bundle all `samples` that were sampled from the `model` with the given `sampler` in a chain.
This function is not intended to return any value, any set up should mutate the `sampler`
or the `model` in-place. A common use for `sample_init!` might be to instantiate a particle
field for later use, or find an initial step size for a Hamiltonian sampler.
"""
function sample_init!(
::Random.AbstractRNG,
model::AbstractModel,
sampler::AbstractSampler,
::Integer;
kwargs...
)
@debug "the default `sample_init!` function is used" typeof(model) typeof(sampler)
return
end
The final `state` of the `sampler` can be included in the chain. The type of the chain can
be specified with the `chain_type` argument.
By default, this method returns `samples`.
"""
sample_end!(rng, model, sampler, N, transitions[; kwargs...])
Perform final modifications after sampling from the MCMC `sampler` for the provided `model`,
resulting in the provided `transitions`.
This function is not intended to return any value, any set up should mutate the `sampler`
or the `model` in-place.
This function is useful in cases where you might want to transform the `transitions`,
save the `sampler` to disk, or perform any clean-up or finalization.
"""
function sample_end!(
::Random.AbstractRNG,
model::AbstractModel,
sampler::AbstractSampler,
::Integer,
transitions;
kwargs...
)
@debug "the default `sample_end!` function is used" typeof(model) typeof(sampler) typeof(transitions)
return
end

function bundle_samples(
::Random.AbstractRNG,
samples,
::AbstractModel,
::AbstractSampler,
::Integer,
transitions,
::Type{Any};
::Any,
::Type;
kwargs...
)
return transitions
return samples
end

"""
step!(rng, model, sampler[, N = 1, transition = nothing; kwargs...])
Return the transition for the next step of the MCMC `sampler` for the provided `model`,
using the provided random number generator `rng`.
step(rng, model, sampler[, state; kwargs...])
Transitions describe the results of a single step of the `sampler`. As an example, a
transition might include a vector of parameters sampled from a prior distribution.
Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`.
The `step!` function may modify the `model` or the `sampler` in-place. For example, the
`sampler` may have a state variable that contains a vector of particles or some other value
that does not need to be included in the returned transition.
Samples describe the results of a single step of the `sampler`. As an example, a sample
might include a vector of parameters sampled from a prior distribution.
When sampling from the `sampler` using [`sample`](@ref), every `step!` call after the first
has access to the previous `transition`. In the first call, `transition` is set to `nothing`.
When sampling using [`sample`](@ref), every `step` call after the first has access to the
current `state` of the sampler.
"""
function step!(
rng::Random.AbstractRNG,
model::AbstractModel,
sampler::AbstractSampler,
N::Integer = 1;
kwargs...
)
return step!(rng, model, sampler, N, nothing; kwargs...)
end
function step end

"""
transitions(transition, model, sampler, N[; kwargs...])
transitions(transition, model, sampler[; kwargs...])
samples(sample, model, sampler[, N; kwargs...])
Generate a container for the `N` transitions of the MCMC `sampler` for the provided
`model`, whose first transition is `transition`.
Generate a container for the samples of the MCMC `sampler` for the `model`, whose first
sample is `sample`.
The method can be called with and without a predefined size `N`.
The method can be called with and without a predefined number `N` of samples.
"""
function transitions(
transition,
function samples(
sample,
::AbstractModel,
::AbstractSampler,
N::Integer;
kwargs...
)
ts = Vector{typeof(transition)}(undef, 0)
ts = Vector{typeof(sample)}(undef, 0)
sizehint!(ts, N)
return ts
end

function transitions(
transition,
function samples(
sample,
::AbstractModel,
::AbstractSampler;
kwargs...
)
return Vector{typeof(transition)}(undef, 0)
return Vector{typeof(sample)}(undef, 0)
end

"""
save!!(transitions, transition, iteration, model, sampler, N[; kwargs...])
save!!(transitions, transition, iteration, model, sampler[; kwargs...])
save!!(samples, sample, iteration, model, sampler[, N; kwargs...])
Save the `transition` of the MCMC `sampler` at the current `iteration` in the container of
`transitions`.
Save the `sample` of the MCMC `sampler` at the current `iteration` in the container of
`samples`.
The function can be called with and without a predefined size `N`. By default, AbstractMCMC
uses ``push!!`` from the Julia package [BangBang](https://github.com/tkf/BangBang.jl) to
append to the container, and widen its type if needed.
The function can be called with and without a predefined number `N` of samples. By default,
AbstractMCMC uses ``push!!`` from the Julia package
[BangBang](https://github.com/tkf/BangBang.jl) to append to the container, and widen its
type if needed.
"""
function save!!(
transitions::Vector,
transition,
samples::Vector,
sample,
iteration::Integer,
::AbstractModel,
::AbstractSampler,
N::Integer;
kwargs...
)
new_ts = BangBang.push!!(transitions, transition)
new_ts !== transitions && sizehint!(new_ts, N)
return new_ts
s = BangBang.push!!(samples, sample)
s !== samples && sizehint!(s, N)
return s
end

function save!!(
transitions,
transition,
samples,
sample,
iteration::Integer,
::AbstractModel,
::AbstractSampler;
kwargs...
)
return BangBang.push!!(transitions, transition)
return BangBang.push!!(samples, sample)
end

Base.@deprecate transitions_init(transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs...) transitions(transition, model, sampler, N; kwargs...) false
Base.@deprecate transitions_init(transition, model::AbstractModel, sampler::AbstractSampler; kwargs...) transitions(transition, model, sampler; kwargs...) false
Base.@deprecate transitions_save!(transitions, iteration::Integer, transition, model::AbstractModel, sampler::AbstractSampler; kwargs...) save!!(transitions, transition, iteration, model, sampler; kwargs...) false
Base.@deprecate transitions_save!(transitions, iteration::Integer, transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs...) save!!(transitions, transition, iteration, model, sampler, N; kwargs...) false
# Deprecations
Base.@deprecate transitions(
transition,
model::AbstractModel,
sampler::AbstractSampler,
N::Integer;
kwargs...
) samples(transition, model, sampler, N; kwargs...) false
Base.@deprecate transitions(
transition,
model::AbstractModel,
sampler::AbstractSampler;
kwargs...
) samples(transition, model, sampler; kwargs...) false
Loading

2 comments on commit f03a17d

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17325

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.0.0 -m "<description of version>" f03a17d63fa794413cbf620ecc7ba46fba9b480b
git push origin v2.0.0

Please sign in to comment.