diff --git a/Project.toml b/Project.toml index ccb4ac94..a16d2596 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "1.0.1" +version = "2.0.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/interface.jl b/src/interface.jl index eac5a668..ea972a01 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -6,154 +6,115 @@ Concatenate multiple chains. chainscat(c::AbstractChains...) = cat(c...; dims=3) """ - sample_init!(rng, model, sampler, N[; kwargs...]) + bundle_samples(samples, model, sampler, state, chain_type[; kwargs...]) -Perform the initial setup of the MCMC `sampler` for the provided `model`. +Bundle all `samples` that were sampled from the `model` with the given `sampler` in a chain. -This function is not intended to return any value, any set up should mutate the `sampler` -or the `model` in-place. A common use for `sample_init!` might be to instantiate a particle -field for later use, or find an initial step size for a Hamiltonian sampler. -""" -function sample_init!( - ::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler, - ::Integer; - kwargs... -) - @debug "the default `sample_init!` function is used" typeof(model) typeof(sampler) - return -end +The final `state` of the `sampler` can be included in the chain. The type of the chain can +be specified with the `chain_type` argument. +By default, this method returns `samples`. """ - sample_end!(rng, model, sampler, N, transitions[; kwargs...]) - -Perform final modifications after sampling from the MCMC `sampler` for the provided `model`, -resulting in the provided `transitions`. - -This function is not intended to return any value, any set up should mutate the `sampler` -or the `model` in-place. - -This function is useful in cases where you might want to transform the `transitions`, -save the `sampler` to disk, or perform any clean-up or finalization. -""" -function sample_end!( - ::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler, - ::Integer, - transitions; - kwargs... -) - @debug "the default `sample_end!` function is used" typeof(model) typeof(sampler) typeof(transitions) - return -end - function bundle_samples( - ::Random.AbstractRNG, + samples, ::AbstractModel, ::AbstractSampler, - ::Integer, - transitions, - ::Type{Any}; + ::Any, + ::Type; kwargs... ) - return transitions + return samples end """ - step!(rng, model, sampler[, N = 1, transition = nothing; kwargs...]) - -Return the transition for the next step of the MCMC `sampler` for the provided `model`, -using the provided random number generator `rng`. + step(rng, model, sampler[, state; kwargs...]) -Transitions describe the results of a single step of the `sampler`. As an example, a -transition might include a vector of parameters sampled from a prior distribution. +Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`. -The `step!` function may modify the `model` or the `sampler` in-place. For example, the -`sampler` may have a state variable that contains a vector of particles or some other value -that does not need to be included in the returned transition. +Samples describe the results of a single step of the `sampler`. As an example, a sample +might include a vector of parameters sampled from a prior distribution. -When sampling from the `sampler` using [`sample`](@ref), every `step!` call after the first -has access to the previous `transition`. In the first call, `transition` is set to `nothing`. +When sampling using [`sample`](@ref), every `step` call after the first has access to the +current `state` of the sampler. """ -function step!( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler, - N::Integer = 1; - kwargs... -) - return step!(rng, model, sampler, N, nothing; kwargs...) -end +function step end """ - transitions(transition, model, sampler, N[; kwargs...]) - transitions(transition, model, sampler[; kwargs...]) + samples(sample, model, sampler[, N; kwargs...]) -Generate a container for the `N` transitions of the MCMC `sampler` for the provided -`model`, whose first transition is `transition`. +Generate a container for the samples of the MCMC `sampler` for the `model`, whose first +sample is `sample`. -The method can be called with and without a predefined size `N`. +The method can be called with and without a predefined number `N` of samples. """ -function transitions( - transition, +function samples( + sample, ::AbstractModel, ::AbstractSampler, N::Integer; kwargs... ) - ts = Vector{typeof(transition)}(undef, 0) + ts = Vector{typeof(sample)}(undef, 0) sizehint!(ts, N) return ts end -function transitions( - transition, +function samples( + sample, ::AbstractModel, ::AbstractSampler; kwargs... ) - return Vector{typeof(transition)}(undef, 0) + return Vector{typeof(sample)}(undef, 0) end """ - save!!(transitions, transition, iteration, model, sampler, N[; kwargs...]) - save!!(transitions, transition, iteration, model, sampler[; kwargs...]) + save!!(samples, sample, iteration, model, sampler[, N; kwargs...]) -Save the `transition` of the MCMC `sampler` at the current `iteration` in the container of -`transitions`. +Save the `sample` of the MCMC `sampler` at the current `iteration` in the container of +`samples`. -The function can be called with and without a predefined size `N`. By default, AbstractMCMC -uses ``push!!`` from the Julia package [BangBang](https://github.com/tkf/BangBang.jl) to -append to the container, and widen its type if needed. +The function can be called with and without a predefined number `N` of samples. By default, +AbstractMCMC uses ``push!!`` from the Julia package +[BangBang](https://github.com/tkf/BangBang.jl) to append to the container, and widen its +type if needed. """ function save!!( - transitions::Vector, - transition, + samples::Vector, + sample, iteration::Integer, ::AbstractModel, ::AbstractSampler, N::Integer; kwargs... ) - new_ts = BangBang.push!!(transitions, transition) - new_ts !== transitions && sizehint!(new_ts, N) - return new_ts + s = BangBang.push!!(samples, sample) + s !== samples && sizehint!(s, N) + return s end function save!!( - transitions, - transition, + samples, + sample, iteration::Integer, ::AbstractModel, ::AbstractSampler; kwargs... ) - return BangBang.push!!(transitions, transition) + return BangBang.push!!(samples, sample) end -Base.@deprecate transitions_init(transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs...) transitions(transition, model, sampler, N; kwargs...) false -Base.@deprecate transitions_init(transition, model::AbstractModel, sampler::AbstractSampler; kwargs...) transitions(transition, model, sampler; kwargs...) false -Base.@deprecate transitions_save!(transitions, iteration::Integer, transition, model::AbstractModel, sampler::AbstractSampler; kwargs...) save!!(transitions, transition, iteration, model, sampler; kwargs...) false -Base.@deprecate transitions_save!(transitions, iteration::Integer, transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs...) save!!(transitions, transition, iteration, model, sampler, N; kwargs...) false +# Deprecations +Base.@deprecate transitions( + transition, + model::AbstractModel, + sampler::AbstractSampler, + N::Integer; + kwargs... +) samples(transition, model, sampler, N; kwargs...) false +Base.@deprecate transitions( + transition, + model::AbstractModel, + sampler::AbstractSampler; + kwargs... +) samples(transition, model, sampler; kwargs...) false diff --git a/src/sample.jl b/src/sample.jl index 3f04b5f2..f8bd6370 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -52,7 +52,7 @@ Return `N` samples from the MCMC `sampler` for the provided `model`. A callback function `f` with type signature ```julia -f(rng, model, sampler, transition, iteration) +f(rng, model, sampler, sample, iteration) ``` may be provided as keyword argument `callback`. It is called after every sampling step. """ @@ -63,50 +63,44 @@ function mcmcsample( N::Integer; progress = true, progressname = "Sampling", - callback = (args...) -> nothing, + callback = nothing, chain_type::Type=Any, kwargs... ) # Check the number of requested samples. N > 0 || error("the number of samples must be ≥ 1") - # Perform any necessary setup. - sample_init!(rng, model, sampler, N; kwargs...) - @ifwithprogresslogger progress name=progressname begin - # Obtain the initial transition. - transition = step!(rng, model, sampler, N; iteration=1, kwargs...) + # Obtain the initial sample and state. + sample, state = step(rng, model, sampler; kwargs...) # Run callback. - callback(rng, model, sampler, transition, 1) + callback === nothing || callback(rng, model, sampler, sample, 1) - # Save the transition. - transitions = AbstractMCMC.transitions(transition, model, sampler, N; kwargs...) - transitions = save!!(transitions, transition, 1, model, sampler, N; kwargs...) + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) + samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) # Update the progress bar. progress && ProgressLogging.@logprogress 1/N # Step through the sampler. for i in 2:N - # Obtain the next transition. - transition = step!(rng, model, sampler, N, transition; iteration=i, kwargs...) + # Obtain the next sample and state. + sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback(rng, model, sampler, transition, i) + callback === nothing || callback(rng, model, sampler, sample, i) - # Save the transition. - transitions = save!!(transitions, transition, i, model, sampler, N; kwargs...) + # Save the sample. + samples = save!!(samples, sample, i, model, sampler, N; kwargs...) # Update the progress bar. progress && ProgressLogging.@logprogress i/N end end - # Wrap up the sampler, if necessary. - sample_end!(rng, model, sampler, N, transitions; kwargs...) - - return bundle_samples(rng, model, sampler, N, transitions, chain_type; kwargs...) + return bundle_samples(samples, model, sampler, state, chain_type; kwargs...) end """ @@ -116,13 +110,13 @@ Continuously draw samples until a convergence criterion `isdone` returns `true`. The function `isdone` has the signature ```julia -isdone(rng, model, sampler, transitions, iteration; kwargs...) +isdone(rng, model, sampler, samples, iteration; kwargs...) ``` and should return `true` when sampling should end, and `false` otherwise. A callback function `f` with type signature ```julia -f(rng, model, sampler, transition, iteration) +f(rng, model, sampler, sample, iteration) ``` may be provided as keyword argument `callback`. It is called after every sampling step. """ @@ -134,46 +128,40 @@ function mcmcsample( chain_type::Type=Any, progress = true, progressname = "Convergence sampling", - callback = (args...) -> nothing, + callback = nothing, kwargs... ) - # Perform any necessary setup. - sample_init!(rng, model, sampler, 1; kwargs...) - @ifwithprogresslogger progress name=progressname begin - # Obtain the initial transition. - transition = step!(rng, model, sampler, 1; iteration=1, kwargs...) + # Obtain the initial sample and state. + sample, state = step(rng, model, sampler; kwargs...) # Run callback. - callback(rng, model, sampler, transition, 1) + callback === nothing || callback(rng, model, sampler, sample, 1) - # Save the transition. - transitions = AbstractMCMC.transitions(transition, model, sampler; kwargs...) - transitions = save!!(transitions, transition, 1, model, sampler; kwargs...) + # Save the sample. + samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) + samples = save!!(samples, sample, 1, model, sampler; kwargs...) # Step through the sampler until stopping. i = 2 - while !isdone(rng, model, sampler, transitions, i; progress=progress, kwargs...) - # Obtain the next transition. - transition = step!(rng, model, sampler, 1, transition; iteration=i, kwargs...) + while !isdone(rng, model, sampler, samples, i; progress=progress, kwargs...) + # Obtain the next sample and state. + sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback(rng, model, sampler, transition, i) + callback === nothing || callback(rng, model, sampler, sample, i) - # Save the transition. - transitions = save!!(transitions, transition, i, model, sampler; kwargs...) + # Save the sample. + samples = save!!(samples, sample, i, model, sampler; kwargs...) # Increment iteration counter. i += 1 end end - # Wrap up the sampler, if necessary. - sample_end!(rng, model, sampler, i, transitions; kwargs...) - # Wrap the samples up. - return bundle_samples(rng, model, sampler, i, transitions, chain_type; kwargs...) + return bundle_samples(samples, model, sampler, state, chain_type; kwargs...) end """ @@ -237,24 +225,26 @@ function mcmcsample( end Distributed.@async begin - Threads.@threads for i in 1:nchains - # Obtain the ID of the current thread. - id = Threads.threadid() + try + Threads.@threads for i in 1:nchains + # Obtain the ID of the current thread. + id = Threads.threadid() - # Seed the thread-specific random number generator with the pre-made seed. - subrng = rngs[id] - Random.seed!(subrng, seeds[i]) + # Seed the thread-specific random number generator with the pre-made seed. + subrng = rngs[id] + Random.seed!(subrng, seeds[i]) - # Sample a chain and save it to the vector. - chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N; - progress = false, kwargs...) + # Sample a chain and save it to the vector. + chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N; + progress = false, kwargs...) - # Update the progress bar. - progress && put!(channel, true) + # Update the progress bar. + progress && put!(channel, true) + end + finally + # Stop updating the progress bar. + progress && put!(channel, false) end - - # Stop updating the progress bar. - progress && put!(channel, false) end end end @@ -308,15 +298,13 @@ function mcmcsample( end Distributed.@async begin - chains = let rng=rng, model=model, sampler=sampler, N=N, channel=channel, - kwargs=kwargs - Distributed.pmap(pool, seeds) do seed + try + chains = Distributed.pmap(pool, seeds) do seed # Seed a new random number generator with the pre-made seed. - subrng = deepcopy(rng) - Random.seed!(subrng, seed) + Random.seed!(rng, seed) # Sample a chain. - chain = StatsBase.sample(subrng, model, sampler, N; + chain = StatsBase.sample(rng, model, sampler, N; progress = false, kwargs...) # Update the progress bar. @@ -325,10 +313,10 @@ function mcmcsample( # Return the new chain. return chain end + finally + # Stop updating the progress bar. + progress && put!(channel, false) end - - # Stop updating the progress bar. - progress && put!(channel, false) end end end @@ -336,8 +324,3 @@ function mcmcsample( # Concatenate the chains together. return reduce(chainscat, chains) end - -# Deprecations. -Base.@deprecate psample(model, sampler, N, nchains; kwargs...) sample(model, sampler, MCMCThreads(), N, nchains; kwargs...) false -Base.@deprecate psample(rng, model, sampler, N, nchains; kwargs...) sample(rng, model, sampler, MCMCThreads(), N, nchains; kwargs...) false -Base.@deprecate mcmcpsample(rng, model, sampler, N, nchains; kwargs...) mcmcsample(rng, model, sampler, MCMCThreads(), N, nchains; kwargs...) false diff --git a/src/stepper.jl b/src/stepper.jl index 88463f04..e15ec0e4 100644 --- a/src/stepper.jl +++ b/src/stepper.jl @@ -1,47 +1,46 @@ struct Stepper{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} rng::A model::M - s::S + sampler::S kwargs::K end -function Base.iterate(stp::Stepper, state=nothing) - t = step!(stp.rng, stp.model, stp.s, 1, state; stp.kwargs...) - return t, t +Base.iterate(stp::Stepper) = step(stp.rng, stp.model, stp.sampler; stp.kwargs...) +function Base.iterate(stp::Stepper, state) + return step(stp.rng, stp.model, stp.sampler, state; stp.kwargs...) end Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite() Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown() """ - steps!([rng::AbstractRNG, ]model::AbstractModel, s::AbstractSampler, kwargs...) + steps([rng::AbstractRNG, ]model::AbstractModel, s::AbstractSampler, kwargs...) -`steps!` returns an iterator that returns samples continuously, after calling `sample_init!`. +Return an iterator that returns samples continuously. -Usage: +# Examples ```julia -for transition in steps!(MyModel(), MySampler()) +for transition in steps(MyModel(), MySampler()) println(transition) # Do other stuff with transition below. end ``` """ -function steps!( +function steps( model::AbstractModel, - s::AbstractSampler, + sampler::AbstractSampler, kwargs... ) - return steps!(Random.GLOBAL_RNG, model, s; kwargs...) + return steps(Random.GLOBAL_RNG, model, sampler; kwargs...) end -function steps!( +function steps( rng::Random.AbstractRNG, model::AbstractModel, - s::AbstractSampler, + sampler::AbstractSampler, kwargs... ) - sample_init!(rng, model, s, 0) - return Stepper(rng, model, s, kwargs) + return Stepper(rng, model, sampler, kwargs) end diff --git a/src/transducer.jl b/src/transducer.jl index 60841459..2acdd176 100644 --- a/src/transducer.jl +++ b/src/transducer.jl @@ -15,21 +15,23 @@ function Sample( sampler::AbstractSampler; kwargs... ) - sample_init!(rng, model, sampler, 0) return Sample(rng, model, sampler, kwargs) end function Transducers.start(rf::Transducers.R_{<:Sample}, result) - return Transducers.wrap(rf, nothing, Transducers.start(Transducers.inner(rf), result)) + sampler = Transducers.xform(rf) + return Transducers.wrap( + rf, + step(sampler.rng, sampler.model, sampler.sampler; sampler.kwargs...), + Transducers.start(Transducers.inner(rf), result), + ) end function Transducers.next(rf::Transducers.R_{<:Sample}, result, input) t = Transducers.xform(rf) - Transducers.wrapping(rf, result) do state, iresult - transition = step!(t.rng, t.model, t.sampler, 1, state; t.kwargs...) - iinput = transition - iresult = Transducers.next(Transducers.inner(rf), iresult, transition) - return transition, iresult + Transducers.wrapping(rf, result) do (sample, state), iresult + iresult2 = Transducers.next(Transducers.inner(rf), iresult, sample) + return step(t.rng, t.model, t.sampler, state; t.kwargs...), iresult2 end end diff --git a/test/deprecations.jl b/test/deprecations.jl new file mode 100644 index 00000000..f866668c --- /dev/null +++ b/test/deprecations.jl @@ -0,0 +1,4 @@ +@testset "deprecations.jl" begin + @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler()) + @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler(), 3) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0e701448..c3f108e1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,4 @@ using AbstractMCMC -using AbstractMCMC: steps! using Atom.Progress: JunoProgressLogger using ConsoleProgressMonitor: ProgressLogger using IJulia @@ -17,266 +16,11 @@ using Test: collect_test_logs const LOGGERS = Set() const CURRENT_LOGGER = Logging.current_logger() -include("interface.jl") +include("utils.jl") @testset "AbstractMCMC" begin - @testset "Basic sampling" begin - @testset "REPL" begin - empty!(LOGGERS) - - Random.seed!(1234) - N = 1_000 - chain = sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) - - @test length(LOGGERS) == 1 - logger = first(LOGGERS) - @test logger isa TeeLogger - @test logger.loggers[1].logger isa TerminalLogger - @test logger.loggers[2].logger === CURRENT_LOGGER - @test Logging.current_logger() === CURRENT_LOGGER - - # test output type and size - @test chain isa Vector{<:MyTransition} - @test length(chain) == N - - # test some statistical properties - tail_chain = @view chain[2:end] - @test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2 - @test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3 - @test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2 - @test var(x.b for x in tail_chain) ≈ 1 atol=6e-2 - end - - @testset "Juno" begin - empty!(LOGGERS) - - Random.seed!(1234) - N = 10 - - logger = JunoProgressLogger() - Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) - end - - @test length(LOGGERS) == 1 - @test first(LOGGERS) === logger - @test Logging.current_logger() === CURRENT_LOGGER - end - - @testset "IJulia" begin - # emulate running IJulia kernel - @eval IJulia begin - inited = true - end - - empty!(LOGGERS) - - Random.seed!(1234) - N = 10 - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) - - @test length(LOGGERS) == 1 - logger = first(LOGGERS) - @test logger isa TeeLogger - @test logger.loggers[1].logger isa ProgressLogger - @test logger.loggers[2].logger === CURRENT_LOGGER - @test Logging.current_logger() === CURRENT_LOGGER - - @eval IJulia begin - inited = false - end - end - - @testset "Custom logger" begin - empty!(LOGGERS) - - Random.seed!(1234) - N = 10 - - logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1)) - Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) - end - - @test length(LOGGERS) == 1 - @test first(LOGGERS) === logger - @test Logging.current_logger() === CURRENT_LOGGER - end - - @testset "Suppress output" begin - logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), 100; progress = false, sleepy = true) - end - @test all(l.level > Logging.LogLevel(-1) for l in logs) - end - end - - if VERSION ≥ v"1.3" - @testset "Multithreaded sampling" begin - if Threads.nthreads() == 1 - warnregex = r"^Only a single thread available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), - 10, 10; chain_type = MyChain) - end - - Random.seed!(1234) - N = 10_000 - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) - - # test output type and size - @test chains isa Vector{<:MyChain} - @test length(chains) == 1000 - @test all(x -> length(x.as) == length(x.bs) == N, chains) - - # test some statistical properties - @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=1e-2), chains) - @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) - @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) - - # test reproducibility - Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - - # Unexpected order of arguments. - str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCThreads(), 5, 10; - chain_type = MyChain) - - # Suppress output. - logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; - progress = false, chain_type = MyChain) - end - @test all(l.level > Logging.LogLevel(-1) for l in logs) - - # Smoke test for nchains < nthreads - if Threads.nthreads() == 2 - sample(MyModel(), MySampler(), MCMCThreads(), N, 1) - end - end - end - - @testset "Multicore sampling" begin - if nworkers() == 1 - warnregex = r"^Only a single process available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCDistributed(), - 10, 10; chain_type = MyChain) - end - - # Add worker processes. - addprocs() - - # Load all required packages (`interface.jl` needs Random). - @everywhere begin - using AbstractMCMC - using AbstractMCMC: sample - - using Random - include("interface.jl") - end - - N = 10_000 - Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; - chain_type = MyChain) - - # Test output type and size. - @test chains isa Vector{<:MyChain} - @test all(c.as[1] === missing for c in chains) - @test length(chains) == 1000 - @test all(x -> length(x.as) == length(x.bs) == N, chains) - - # Test some statistical properties. - @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=1e-2), chains) - @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) - @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) - - # Test reproducibility. - Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; - chain_type = MyChain) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - - # Unexpected order of arguments. - str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCDistributed(), 5, 10; - chain_type = MyChain) - - # Suppress output. - logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 100; - progress = false, chain_type = MyChain) - end - @test all(l.level > Logging.LogLevel(-1) for l in logs) - end - - @testset "Chain constructors" begin - chain1 = sample(MyModel(), MySampler(), 100; sleepy = true) - chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain) - - @test chain1 isa Vector{<:MyTransition} - @test chain2 isa MyChain - end - - @testset "Iterator sampling" begin - Random.seed!(1234) - as = [] - bs = [] - - iter = steps!(MyModel(), MySampler()) - - for (count, t) in enumerate(iter) - if count >= 1000 - break - end - - # don't save missing values - t.a === missing && continue - - push!(as, t.a) - push!(bs, t.b) - end - - @test length(as) == length(bs) == 998 - - @test mean(as) ≈ 0.5 atol=1e-2 - @test var(as) ≈ 1 / 12 atol=5e-3 - @test mean(bs) ≈ 0.0 atol=5e-2 - @test var(bs) ≈ 1 atol=5e-2 - - println(eltype(iter)) - @test Base.IteratorSize(iter) == Base.IsInfinite() - @test Base.IteratorEltype(iter) == Base.EltypeUnknown() - end - - @testset "Sample without predetermined N" begin - Random.seed!(1234) - chain = sample(MyModel(), MySampler()) - bmean = mean(x.b for x in chain) - @test abs(bmean) <= 0.001 && length(chain) < 10_000 - end - - @testset "Deprecations" begin - @test_deprecated AbstractMCMC.psample(MyModel(), MySampler(), 10, 10; - chain_type = MyChain) - @test_deprecated AbstractMCMC.psample(Random.GLOBAL_RNG, MyModel(), MySampler(), - 10, 10; - chain_type = MyChain) - @test_deprecated AbstractMCMC.mcmcpsample(Random.GLOBAL_RNG, MyModel(), - MySampler(), 10, 10; - chain_type = MyChain) - end - + include("sample.jl") + include("stepper.jl") include("transducer.jl") + include("deprecations.jl") end diff --git a/test/sample.jl b/test/sample.jl new file mode 100644 index 00000000..9e1eb726 --- /dev/null +++ b/test/sample.jl @@ -0,0 +1,217 @@ +@testset "sample.jl" begin + @testset "Basic sampling" begin + @testset "REPL" begin + empty!(LOGGERS) + + Random.seed!(1234) + N = 1_000 + chain = sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + + @test length(LOGGERS) == 1 + logger = first(LOGGERS) + @test logger isa TeeLogger + @test logger.loggers[1].logger isa TerminalLogger + @test logger.loggers[2].logger === CURRENT_LOGGER + @test Logging.current_logger() === CURRENT_LOGGER + + # test output type and size + @test chain isa Vector{<:MySample} + @test length(chain) == N + + # test some statistical properties + tail_chain = @view chain[2:end] + @test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2 + @test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3 + @test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2 + @test var(x.b for x in tail_chain) ≈ 1 atol=6e-2 + end + + @testset "Juno" begin + empty!(LOGGERS) + + Random.seed!(1234) + N = 10 + + logger = JunoProgressLogger() + Logging.with_logger(logger) do + sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + end + + @test length(LOGGERS) == 1 + @test first(LOGGERS) === logger + @test Logging.current_logger() === CURRENT_LOGGER + end + + @testset "IJulia" begin + # emulate running IJulia kernel + @eval IJulia begin + inited = true + end + + empty!(LOGGERS) + + Random.seed!(1234) + N = 10 + sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + + @test length(LOGGERS) == 1 + logger = first(LOGGERS) + @test logger isa TeeLogger + @test logger.loggers[1].logger isa ProgressLogger + @test logger.loggers[2].logger === CURRENT_LOGGER + @test Logging.current_logger() === CURRENT_LOGGER + + @eval IJulia begin + inited = false + end + end + + @testset "Custom logger" begin + empty!(LOGGERS) + + Random.seed!(1234) + N = 10 + + logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1)) + Logging.with_logger(logger) do + sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + end + + @test length(LOGGERS) == 1 + @test first(LOGGERS) === logger + @test Logging.current_logger() === CURRENT_LOGGER + end + + @testset "Suppress output" begin + logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do + sample(MyModel(), MySampler(), 100; progress = false, sleepy = true) + end + @test all(l.level > Logging.LogLevel(-1) for l in logs) + end + end + + if VERSION ≥ v"1.3" + @testset "Multithreaded sampling" begin + if Threads.nthreads() == 1 + warnregex = r"^Only a single thread available" + @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), + 10, 10; chain_type = MyChain) + end + + Random.seed!(1234) + N = 10_000 + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; + chain_type = MyChain) + + # test output type and size + @test chains isa Vector{<:MyChain} + @test length(chains) == 1000 + @test all(x -> length(x.as) == length(x.bs) == N, chains) + + # test some statistical properties + @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=1e-2), chains) + @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) + @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + + # test reproducibility + Random.seed!(1234) + chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; + chain_type = MyChain) + + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + + # Unexpected order of arguments. + str = "Number of chains (10) is greater than number of samples per chain (5)" + @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), + MCMCThreads(), 5, 10; + chain_type = MyChain) + + # Suppress output. + logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do + sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; + progress = false, chain_type = MyChain) + end + @test all(l.level > Logging.LogLevel(-1) for l in logs) + + # Smoke test for nchains < nthreads + if Threads.nthreads() == 2 + sample(MyModel(), MySampler(), MCMCThreads(), N, 1) + end + end + end + + @testset "Multicore sampling" begin + if nworkers() == 1 + warnregex = r"^Only a single process available" + @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCDistributed(), + 10, 10; chain_type = MyChain) + end + + # Add worker processes. + addprocs() + + # Load all required packages (`interface.jl` needs Random). + @everywhere begin + using AbstractMCMC + using AbstractMCMC: sample + + using Random + include("utils.jl") + end + + N = 10_000 + Random.seed!(1234) + chains = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; + chain_type = MyChain) + + # Test output type and size. + @test chains isa Vector{<:MyChain} + @test all(c.as[1] === missing for c in chains) + @test length(chains) == 1000 + @test all(x -> length(x.as) == length(x.bs) == N, chains) + + # Test some statistical properties. + @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=1e-2), chains) + @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) + @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + + # Test reproducibility. + Random.seed!(1234) + chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; + chain_type = MyChain) + + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + + # Unexpected order of arguments. + str = "Number of chains (10) is greater than number of samples per chain (5)" + @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), + MCMCDistributed(), 5, 10; + chain_type = MyChain) + + # Suppress output. + logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do + sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 100; + progress = false, chain_type = MyChain) + end + @test all(l.level > Logging.LogLevel(-1) for l in logs) + end + + @testset "Chain constructors" begin + chain1 = sample(MyModel(), MySampler(), 100; sleepy = true) + chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain) + + @test chain1 isa Vector{<:MySample} + @test chain2 isa MyChain + end + + @testset "Sample without predetermined N" begin + Random.seed!(1234) + chain = sample(MyModel(), MySampler()) + bmean = mean(x.b for x in chain) + @test abs(bmean) <= 0.001 && length(chain) < 10_000 + end +end \ No newline at end of file diff --git a/test/stepper.jl b/test/stepper.jl new file mode 100644 index 00000000..9d3736f5 --- /dev/null +++ b/test/stepper.jl @@ -0,0 +1,31 @@ +@testset "stepper.jl" begin + @testset "Iterator sampling" begin + Random.seed!(1234) + as = [] + bs = [] + + iter = AbstractMCMC.steps(MyModel(), MySampler()) + + for (count, t) in enumerate(iter) + if count >= 1000 + break + end + + # don't save missing values + t.a === missing && continue + + push!(as, t.a) + push!(bs, t.b) + end + + @test length(as) == length(bs) == 998 + + @test mean(as) ≈ 0.5 atol=1e-2 + @test var(as) ≈ 1 / 12 atol=5e-3 + @test mean(bs) ≈ 0.0 atol=5e-2 + @test var(bs) ≈ 1 atol=5e-2 + + @test Base.IteratorSize(iter) == Base.IsInfinite() + @test Base.IteratorEltype(iter) == Base.EltypeUnknown() + end +end \ No newline at end of file diff --git a/test/transducer.jl b/test/transducer.jl index 03d7dced..943ceae3 100644 --- a/test/transducer.jl +++ b/test/transducer.jl @@ -11,7 +11,7 @@ end # test output type and size - @test chain isa Vector{<:MyTransition} + @test chain isa Vector{<:MySample} @test length(chain) == N # test some statistical properties @@ -25,7 +25,7 @@ @testset "drop" begin xf = AbstractMCMC.Sample(MyModel(), MySampler()) chain = collect(xf |> Drop(1), 1:10) - @test chain isa Vector{MyTransition{Float64,Float64}} + @test chain isa Vector{MySample{Float64,Float64}} @test length(chain) == 9 end @@ -33,7 +33,7 @@ @testset "iterator example" begin # filter missing values and split transitions xf = AbstractMCMC.Sample(MyModel(), MySampler()) |> - OfType(MyTransition{Float64,Float64}) |> Map(x -> (x.a, x.b)) + OfType(MySample{Float64,Float64}) |> Map(x -> (x.a, x.b)) as, bs = foldl(xf, 1:999; init = (Float64[], Float64[])) do (as, bs), (a, b) push!(as, a) push!(bs, b) diff --git a/test/interface.jl b/test/utils.jl similarity index 55% rename from test/interface.jl rename to test/utils.jl index 3426f494..bd9050f1 100644 --- a/test/interface.jl +++ b/test/utils.jl @@ -1,6 +1,6 @@ struct MyModel <: AbstractMCMC.AbstractModel end -struct MyTransition{A,B} +struct MySample{A,B} a::A b::B end @@ -13,55 +13,62 @@ struct MyChain{A,B} <: AbstractMCMC.AbstractChains bs::Vector{B} end -function AbstractMCMC.step!( +function AbstractMCMC.step( rng::AbstractRNG, model::MyModel, sampler::MySampler, - N::Integer, - transition::Union{Nothing,MyTransition}; + state::Union{Nothing,Integer} = nothing; sleepy = false, loggers = false, kwargs... ) # sample `a` is missing in the first step - a = transition === nothing ? missing : rand(rng) + a = state === nothing ? missing : rand(rng) b = randn(rng) loggers && push!(LOGGERS, Logging.current_logger()) sleepy && sleep(0.001) - return MyTransition(a, b) + _state = state === nothing ? 1 : state + 1 + + return MySample(a, b), _state end function AbstractMCMC.bundle_samples( - rng::AbstractRNG, + samples::Vector{<:MySample}, model::MyModel, sampler::MySampler, - N::Integer, - transitions::Vector{<:MyTransition}, - chain_type::Type{MyChain}; + ::Any, + ::Type{MyChain}; kwargs... ) - as = [t.a for t in transitions] - bs = [t.b for t in transitions] + as = [t.a for t in samples] + bs = [t.b for t in samples] return MyChain(as, bs) end -function is_done( +function isdone( rng::AbstractRNG, model::MyModel, s::MySampler, - transitions, + samples, iteration::Int; - chain_type::Type=Any, kwargs... ) # Calculate the mean of x.b. - bmean = mean(x.b for x in transitions) + bmean = mean(x.b for x in samples) return abs(bmean) <= 0.001 || iteration >= 10_000 end # Set a default convergence function. -AbstractMCMC.sample(model, sampler::MySampler; kwargs...) = sample(Random.GLOBAL_RNG, model, sampler, is_done; kwargs...) -AbstractMCMC.chainscat(chains::Union{MyChain,Vector{<:MyChain}}...) = vcat(chains...) +function AbstractMCMC.sample(model, sampler::MySampler; kwargs...) + return sample(Random.GLOBAL_RNG, model, sampler, isdone; kwargs...) +end + +function AbstractMCMC.chainscat( + chain::Union{MyChain,Vector{<:MyChain}}, + chains::Union{MyChain,Vector{<:MyChain}}... +) + return vcat(chain, chains...) +end