Support log density functions as models (#113)
* Update sample.jl

* Update sample.jl

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update api.md

* Update stepper.jl

* Update transducer.jl

* Update api.md

* Update src/stepper.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/transducer.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update Project.toml

* Update src/sample.jl

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Reorganize fallbacks

* Add tests

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update Project.toml

* Define utilities on all workers

* Update test/sample.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
3 people authored Jan 10, 2023
1 parent 2d31f09 commit 33487da
Showing 10 changed files with 296 additions and 34 deletions.
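
In short, this commit lets `sample`, `steps`, and `Sample` accept a plain log density that implements the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface; such an object is wrapped in an `AbstractMCMC.LogDensityModel` internally. The sketch below is illustrative only and not part of the diff: `StandardNormal` is a made-up example type, and `MySampler` is hypothetical, standing in for any `AbstractMCMC.AbstractSampler` with an `AbstractMCMC.step` method (as in the package's test utilities).

```julia
using AbstractMCMC
using AbstractMCMC: sample
using LogDensityProblems
using Random

# A log density implementing the LogDensityProblems.jl interface.
struct StandardNormal
    dim::Int
end
LogDensityProblems.logdensity(::StandardNormal, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(d::StandardNormal) = d.dim
LogDensityProblems.capabilities(::Type{StandardNormal}) = LogDensityProblems.LogDensityOrder{0}()

# With this commit the log density can be passed to `sample` directly;
# internally it is wrapped in an `AbstractMCMC.LogDensityModel`.
chain = sample(Random.default_rng(), StandardNormal(2), MySampler(), 1_000)
```
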
2 changes: 1 addition & 1 deletion Project.toml
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.3.0"
version = "4.4.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
26 changes: 25 additions & 1 deletion docs/src/api.md
@@ -2,23 +2,39 @@

AbstractMCMC defines an interface for sampling Markov chains.

## Model

```@docs
AbstractMCMC.AbstractModel
AbstractMCMC.LogDensityModel
```

## Sampler

```@docs
AbstractMCMC.AbstractSampler
```

## Sampling a single chain

```@docs
AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Integer)
AbstractMCMC.sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler, ::Any)
AbstractMCMC.sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler, ::Any)
```

### Iterator

```@docs
AbstractMCMC.steps(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler)
AbstractMCMC.steps(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler)
```

### Transducer

```@docs
AbstractMCMC.Sample(::AbstractRNG, ::AbstractMCMC.AbstractModel, ::AbstractMCMC.AbstractSampler)
AbstractMCMC.Sample(::AbstractRNG, ::Any, ::AbstractMCMC.AbstractSampler)
```

## Sampling multiple chains in parallel
@@ -32,6 +32,14 @@ AbstractMCMC.sample(
::Integer,
::Integer,
)
AbstractMCMC.sample(
::AbstractRNG,
::Any,
::AbstractMCMC.AbstractSampler,
::AbstractMCMC.AbstractMCMCEnsemble,
::Integer,
::Integer,
)
```

Two algorithms are provided for parallel sampling with multiple threads and multiple processes, and one allows the user to sample multiple chains in serial (no parallelization):
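
Continuing the illustrative sketch from above (hypothetical `MySampler`, example `StandardNormal` type), the new fallback also covers the ensemble methods documented here, so a log density can be sampled in parallel directly:

```julia
# Four chains of 1_000 draws each, sampled on multiple threads.
chains = sample(StandardNormal(2), MySampler(), MCMCThreads(), 1_000, 4)

# MCMCDistributed() and MCMCSerial() work the same way.
```
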
92 changes: 92 additions & 0 deletions src/logdensityproblems.jl
@@ -25,3 +25,95 @@ struct LogDensityModel{L} <: AbstractModel
end

LogDensityModel(logdensity::L) where {L} = LogDensityModel{L}(logdensity)

# Fallbacks: Wrap log density function in a model
"""
sample(
rng::Random.AbstractRNG=Random.default_rng(),
logdensity,
sampler::AbstractSampler,
N_or_isdone;
kwargs...,
)
Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`.
The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface.
"""
function StatsBase.sample(
rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler, N_or_isdone; kwargs...
)
return StatsBase.sample(rng, _model(logdensity), sampler, N_or_isdone; kwargs...)
end

"""
sample(
rng::Random.AbstractRNG=Random.default_rng(),
logdensity,
sampler::AbstractSampler,
parallel::AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
kwargs...,
)
Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `sample` with the resulting model instead of `logdensity`.
The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface.
"""
function StatsBase.sample(
rng::Random.AbstractRNG,
logdensity,
sampler::AbstractSampler,
parallel::AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
kwargs...,
)
return StatsBase.sample(
rng, _model(logdensity), sampler, parallel, N, nchains; kwargs...
)
end

"""
steps(
rng::Random.AbstractRNG=Random.default_rng(),
logdensity,
sampler::AbstractSampler;
kwargs...,
)
Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `steps` with the resulting model instead of `logdensity`.
The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface.
"""
function steps(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...)
return steps(rng, _model(logdensity), sampler; kwargs...)
end

"""
Sample(
rng::Random.AbstractRNG=Random.default_rng(),
logdensity,
sampler::AbstractSampler;
kwargs...,
)
Wrap the `logdensity` function in a [`LogDensityModel`](@ref), and call `Sample` with the resulting model instead of `logdensity`.
The `logdensity` function has to support the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface.
"""
function Sample(rng::Random.AbstractRNG, logdensity, sampler::AbstractSampler; kwargs...)
return Sample(rng, _model(logdensity), sampler; kwargs...)
end

function _model(logdensity)
if LogDensityProblems.capabilities(logdensity) === nothing
throw(
ArgumentError(
"the log density function does not support the LogDensityProblems.jl interface. Please implement the interface or provide a model of type `AbstractMCMC.AbstractModel`",
),
)
end
return LogDensityModel(logdensity)
end
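
The `_model` helper above only wraps objects that advertise LogDensityProblems.jl capabilities; everything else is rejected with an `ArgumentError`. A small illustration, reusing the hypothetical `StandardNormal` and `MySampler` from the earlier sketch:

```julia
ℓ = StandardNormal(2)              # declares LogDensityProblems capabilities
AbstractMCMC.LogDensityModel(ℓ)    # wrapping succeeds

f(x) = -sum(abs2, x) / 2           # a bare function with no declared capabilities
sample(f, MySampler(), 10)         # throws ArgumentError
```
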
53 changes: 29 additions & 24 deletions src/sample.jl
@@ -12,32 +12,29 @@ function setprogress!(progress::Bool)
return progress
end

function StatsBase.sample(model::AbstractModel, sampler::AbstractSampler, arg; kwargs...)
return StatsBase.sample(Random.default_rng(), model, sampler, arg; kwargs...)
end

"""
sample([rng, ]model, sampler, N; kwargs...)
Return `N` samples from the `model` with the Markov chain Monte Carlo `sampler`.
"""
function StatsBase.sample(
rng::Random.AbstractRNG,
model::AbstractModel,
sampler::AbstractSampler,
N::Integer;
kwargs...,
model_or_logdensity, sampler::AbstractSampler, N_or_isdone; kwargs...
)
return mcmcsample(rng, model, sampler, N; kwargs...)
return StatsBase.sample(
Random.default_rng(), model_or_logdensity, sampler, N_or_isdone; kwargs...
)
end

"""
sample([rng, ]model, sampler, isdone; kwargs...)
sample(
rng::Random.AbstractRNG=Random.default_rng(),
model::AbstractModel,
sampler::AbstractSampler,
N_or_isdone;
kwargs...,
)
Sample from the `model` with the Markov chain Monte Carlo `sampler` and return the samples.
Sample from the `model` with the Markov chain Monte Carlo `sampler` until a
convergence criterion `isdone` returns `true`, and return the samples.
If `N_or_isdone` is an `Integer`, exactly `N_or_isdone` samples are returned.
The function `isdone` has the signature
Otherwise, sampling is performed until a convergence criterion `N_or_isdone` returns `true`.
The convergence criterion has to be a function with the signature
```julia
isdone(rng, model, sampler, samples, state, iteration; kwargs...)
```
@@ -48,27 +48,35 @@ function StatsBase.sample(
rng::Random.AbstractRNG,
model::AbstractModel,
sampler::AbstractSampler,
isdone;
N_or_isdone;
kwargs...,
)
return mcmcsample(rng, model, sampler, isdone; kwargs...)
return mcmcsample(rng, model, sampler, N_or_isdone; kwargs...)
end

function StatsBase.sample(
model::AbstractModel,
model_or_logdensity,
sampler::AbstractSampler,
parallel::AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
kwargs...,
)
return StatsBase.sample(
Random.default_rng(), model, sampler, parallel, N, nchains; kwargs...
Random.default_rng(), model_or_logdensity, sampler, parallel, N, nchains; kwargs...
)
end

"""
sample([rng, ]model, sampler, parallel, N, nchains; kwargs...)
sample(
rng::Random.AbstractRNG=Random.default_rng(),
model::AbstractModel,
sampler::AbstractSampler,
parallel::AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
kwargs...,
)
Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel
using the `parallel` algorithm, and combine them into a single chain.
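
Besides a fixed number of draws, `N_or_isdone` can be a convergence criterion with the signature documented in the diff above, `isdone(rng, model, sampler, samples, state, iteration; kwargs...)`. A hedged sketch (hypothetical `MySampler`, `StandardNormal` as before):

```julia
# Stop once 1_000 iterations have been performed; any predicate on the
# samples or the sampler state could be used instead.
function enough_samples(rng, model, sampler, samples, state, iteration; kwargs...)
    return iteration >= 1_000
end

chain = sample(StandardNormal(2), MySampler(), enough_samples)
```
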
11 changes: 8 additions & 3 deletions src/stepper.jl
@@ -41,12 +41,17 @@ end
Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite()
Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown()

function steps(model::AbstractModel, sampler::AbstractSampler; kwargs...)
return steps(Random.default_rng(), model, sampler; kwargs...)
function steps(model_or_logdensity, sampler::AbstractSampler; kwargs...)
return steps(Random.default_rng(), model_or_logdensity, sampler; kwargs...)
end

"""
steps([rng, ]model, sampler; kwargs...)
steps(
rng::Random.AbstractRNG=Random.default_rng(),
model::AbstractModel,
sampler::AbstractSampler;
kwargs...,
)
Create an iterator that returns samples from the `model` with the Markov chain Monte Carlo
`sampler`.
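
With this commit the iterator constructor also accepts a raw log density. A brief sketch (hypothetical `MySampler`, `StandardNormal` as before):

```julia
iter = AbstractMCMC.steps(StandardNormal(2), MySampler())
first_draws = collect(Iterators.take(iter, 100))  # take 100 samples from the infinite iterator
```
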
11 changes: 8 additions & 3 deletions src/transducer.jl
@@ -6,12 +6,17 @@ struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <:
kwargs::K
end

function Sample(model::AbstractModel, sampler::AbstractSampler; kwargs...)
return Sample(Random.default_rng(), model, sampler; kwargs...)
function Sample(model_or_logdensity, sampler::AbstractSampler; kwargs...)
return Sample(Random.default_rng(), model_or_logdensity, sampler; kwargs...)
end

"""
Sample([rng, ]model, sampler; kwargs...)
Sample(
rng::Random.AbstractRNG=Random.default_rng(),
model::AbstractModel,
sampler::AbstractSampler;
kwargs...,
)
Create a transducer that returns samples from the `model` with the Markov chain Monte Carlo
`sampler`.
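
The `Sample` transducer gets the same fallback. A brief sketch (hypothetical `MySampler`, `StandardNormal` as before):

```julia
xf = AbstractMCMC.Sample(StandardNormal(2), MySampler())
draws = collect(xf(1:100))  # one draw per element of the input collection
```
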
90 changes: 90 additions & 0 deletions test/logdensityproblems.jl
@@ -0,0 +1,90 @@
@testset "logdensityproblems.jl" begin
# Add worker processes.
# Memory requirements on Windows are ~4x larger than on Linux, hence number of processes is reduced
# See, e.g., https://github.com/JuliaLang/julia/issues/40766 and https://github.com/JuliaLang/Pkg.jl/pull/2366
pids = addprocs(Sys.iswindows() ? div(Sys.CPU_THREADS::Int, 2) : Sys.CPU_THREADS::Int)

# Load all required packages (`utils.jl` needs LogDensityProblems, Logging, and Random).
@everywhere begin
using AbstractMCMC
using AbstractMCMC: sample
using LogDensityProblems

using Logging
using Random
include("utils.jl")
end

@testset "LogDensityModel" begin
ℓ = MyLogDensity(10)
model = @inferred AbstractMCMC.LogDensityModel(ℓ)
@test model isa AbstractMCMC.LogDensityModel{MyLogDensity}
@test model.logdensity === ℓ

@test_throws ArgumentError AbstractMCMC.LogDensityModel(mylogdensity)
end

@testset "fallback for log densities" begin
# Sample with log density
dim = 10
ℓ = MyLogDensity(dim)
Random.seed!(1234)
N = 1_000
samples = sample(ℓ, MySampler(), N)

# Samples are of the correct dimension and log density values are correct
@test length(samples) == N
@test all(length(x.a) == dim for x in samples)
@test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples)

# Same chain as if LogDensityModel is used explicitly
Random.seed!(1234)
samples2 = sample(AbstractMCMC.LogDensityModel(ℓ), MySampler(), N)
@test length(samples2) == N
@test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples2))

# Same chain if sampling is performed with convergence criterion
Random.seed!(1234)
isdone(rng, model, sampler, state, samples, iteration; kwargs...) = iteration > N
samples3 = sample(ℓ, MySampler(), isdone)
@test length(samples3) == N
@test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples3))

# Same chain if sampling is performed with iterator
Random.seed!(1234)
samples4 = collect(Iterators.take(AbstractMCMC.steps(ℓ, MySampler()), N))
@test length(samples4) == N
@test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples4))

# Same chain if sampling is performed with transducer
Random.seed!(1234)
xf = AbstractMCMC.Sample(ℓ, MySampler())
samples5 = collect(xf(1:N))
@test length(samples5) == N
@test all(x.a == y.a && x.b == y.b for (x, y) in zip(samples, samples5))

# Parallel sampling
for alg in (MCMCSerial(), MCMCDistributed(), MCMCThreads())
chains = sample(ℓ, MySampler(), alg, N, 2)
@test length(chains) == 2
samples = vcat(chains[1], chains[2])
@test length(samples) == 2 * N
@test all(length(x.a) == dim for x in samples)
@test all(x.b ≈ LogDensityProblems.logdensity(ℓ, x.a) for x in samples)
end

# Log density has to satisfy the LogDensityProblems interface
@test_throws ArgumentError sample(mylogdensity, MySampler(), N)
@test_throws ArgumentError sample(mylogdensity, MySampler(), isdone)
@test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCSerial(), N, 2)
@test_throws ArgumentError sample(mylogdensity, MySampler(), MCMCThreads(), N, 2)
@test_throws ArgumentError sample(
mylogdensity, MySampler(), MCMCDistributed(), N, 2
)
@test_throws ArgumentError AbstractMCMC.steps(mylogdensity, MySampler())
@test_throws ArgumentError AbstractMCMC.Sample(mylogdensity, MySampler())
end

# Remove workers
rmprocs(pids...)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -2,6 +2,7 @@ using AbstractMCMC
using Atom.Progress: JunoProgressLogger
using ConsoleProgressMonitor: ProgressLogger
using IJulia
using LogDensityProblems
using LoggingExtras: TeeLogger, EarlyFilteredLogger
using TerminalLoggers: TerminalLogger
using Transducers
@@ -22,4 +22,5 @@ include("utils.jl")
include("sample.jl")
include("stepper.jl")
include("transducer.jl")
include("logdensityproblems.jl")
end

2 comments on commit 33487da

@devmotion
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/75447

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.4.0 -m "<description of version>" 33487da76d9874adb7bee1b0509d0a3172580c9a
git push origin v4.4.0
