From 8d7f22f5a047a16b6870ebb15c0090331db8dcaa Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 1 Jun 2022 13:13:30 +0200 Subject: [PATCH] Fix `discard_initial`, and add support for `discard_initial` and `thinning` to iterator and transducer (#102) * Fix `discard_initial`, and add support for `discard_initial` and `thinning` to iterator and transducer * Fix test errors on Julia < 1.6 * Only enable progress logging on Julia < 1.6 * Use different seed * Update api.md * Update api.md * Update sample.jl * Use `==` instead of `===` --- Project.toml | 2 +- docs/src/api.md | 6 ++- src/sample.jl | 4 +- src/stepper.jl | 32 +++++++++++++- src/transducer.jl | 52 +++++++++++++++++++---- test/sample.jl | 101 +++++++++++++++++++++++++++++++++------------ test/stepper.jl | 42 +++++++++++++++++++ test/transducer.jl | 46 +++++++++++++++++++++ 8 files changed, 243 insertions(+), 42 deletions(-) diff --git a/Project.toml b/Project.toml index 6dce9f2c..69a7fb83 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.1.0" +version = "4.1.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/docs/src/api.md b/docs/src/api.md index c7451cc5..9ce28805 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -43,8 +43,7 @@ AbstractMCMC.MCMCSerial ## Common keyword arguments -Common keyword arguments for regular and parallel sampling (not supported by the iterator and transducer) -are: +Common keyword arguments for regular and parallel sampling are: - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then @@ -53,6 +52,9 @@ are: - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. +!!! info + The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref). + There is no "official" way for providing initial parameter values yet. However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): diff --git a/src/sample.jl b/src/sample.jl index 01548470..3b578020 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -120,7 +120,7 @@ function mcmcsample( sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. - for i in 1:(discard_initial - 1) + for i in 1:discard_initial # Update the progress bar. if progress && i >= next_update ProgressLogging.@logprogress i / Ntotal @@ -218,7 +218,7 @@ function mcmcsample( sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. - for _ in 2:discard_initial + for _ in 1:discard_initial # Obtain the next sample and state. sample, state = step(rng, model, sampler, state; kwargs...) end diff --git a/src/stepper.jl b/src/stepper.jl index 18867c58..e7c97eed 100644 --- a/src/stepper.jl +++ b/src/stepper.jl @@ -5,9 +5,37 @@ struct Stepper{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} kwargs::K end -Base.iterate(stp::Stepper) = step(stp.rng, stp.model, stp.sampler; stp.kwargs...) +# Initial sample. +function Base.iterate(stp::Stepper) + # Unpack iterator. + rng = stp.rng + model = stp.model + sampler = stp.sampler + kwargs = stp.kwargs + discard_initial = get(kwargs, :discard_initial, 0)::Int + + # Start sampling algorithm and discard initial samples if desired. + sample, state = step(rng, model, sampler; kwargs...) + for _ in 1:discard_initial + sample, state = step(rng, model, sampler, state; kwargs...) + end + return sample, state +end + +# Subsequent samples. function Base.iterate(stp::Stepper, state) - return step(stp.rng, stp.model, stp.sampler, state; stp.kwargs...) + # Unpack iterator. + rng = stp.rng + model = stp.model + sampler = stp.sampler + kwargs = stp.kwargs + thinning = get(kwargs, :thinning, 1)::Int + + # Return next sample, possibly after thinning the chain if desired. + for _ in 1:(thinning - 1) + _, state = step(rng, model, sampler, state; kwargs...) + end + return step(rng, model, sampler, state; kwargs...) end Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite() diff --git a/src/transducer.jl b/src/transducer.jl index 51f9b358..42df6dba 100644 --- a/src/transducer.jl +++ b/src/transducer.jl @@ -40,24 +40,58 @@ function Sample( return Sample(rng, model, sampler, kwargs) end +# Initial sample. function Transducers.start(rf::Transducers.R_{<:Sample}, result) - sampler = Transducers.xform(rf) + # Unpack transducer. + td = Transducers.xform(rf) + rng = td.rng + model = td.model + sampler = td.sampler + kwargs = td.kwargs + discard_initial = get(kwargs, :discard_initial, 0)::Int + + # Start sampling algorithm and discard initial samples if desired. + sample, state = step(rng, model, sampler; kwargs...) + for _ in 1:discard_initial + sample, state = step(rng, model, sampler, state; kwargs...) + end + return Transducers.wrap( - rf, - step(sampler.rng, sampler.model, sampler.sampler; sampler.kwargs...), - Transducers.start(Transducers.inner(rf), result), + rf, (sample, state), Transducers.start(Transducers.inner(rf), result) ) end +# Subsequent samples. function Transducers.next(rf::Transducers.R_{<:Sample}, result, input) - t = Transducers.xform(rf) - Transducers.wrapping(rf, result) do (sample, state), iresult - iresult2 = Transducers.next(Transducers.inner(rf), iresult, sample) - return step(t.rng, t.model, t.sampler, state; t.kwargs...), iresult2 + # Unpack transducer. + td = Transducers.xform(rf) + rng = td.rng + model = td.model + sampler = td.sampler + kwargs = td.kwargs + thinning = get(kwargs, :thinning, 1)::Int + + let rng = rng, + model = model, + sampler = sampler, + kwargs = kwargs, + thinning = thinning, + inner_rf = Transducers.inner(rf) + + Transducers.wrapping(rf, result) do (sample, state), iresult + iresult2 = Transducers.next(inner_rf, iresult, sample) + + # Perform thinning if desired. + for _ in 1:(thinning - 1) + _, state = step(rng, model, sampler, state; kwargs...) + end + + return step(rng, model, sampler, state; kwargs...), iresult2 + end end end function Transducers.complete(rf::Transducers.R_{Sample}, result) - _private_state, inner_result = Transducers.unwrap(rf, result) + _, inner_result = Transducers.unwrap(rf, result) return Transducers.complete(Transducers.inner(rf), inner_result) end diff --git a/test/sample.jl b/test/sample.jl index f5a69c12..cf080321 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -137,6 +137,7 @@ @test chains isa Vector{<:MyChain} @test length(chains) == 1000 @test all(x -> length(x.as) == length(x.bs) == N, chains) + @test all(ismissing(x.as[1]) for x in chains) # test some statistical properties @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @@ -147,9 +148,9 @@ # test reproducibility Random.seed!(1234) chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(ismissing(x.as[1]) for x in chains2) + @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) + @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" @@ -245,7 +246,7 @@ # Test output type and size. @test chains isa Vector{<:MyChain} - @test all(c.as[1] === missing for c in chains) + @test all(ismissing(c.as[1]) for c in chains) @test length(chains) == 1000 @test all(x -> length(x.as) == length(x.bs) == N, chains) @@ -260,9 +261,9 @@ chains2 = sample( MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain ) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(ismissing(c.as[1]) for c in chains2) + @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) + @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" @@ -330,7 +331,7 @@ # Test output type and size. @test chains isa Vector{<:MyChain} - @test all(c.as[1] === missing for c in chains) + @test all(ismissing(c.as[1]) for c in chains) @test length(chains) == 1000 @test all(x -> length(x.as) == length(x.bs) == N, chains) @@ -343,9 +344,9 @@ # Test reproducibility. Random.seed!(1234) chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) - - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(ismissing(c.as[1]) for c in chains2) + @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) + @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" @@ -415,6 +416,7 @@ progress=false, chain_type=MyChain, ) + @test all(ismissing(c.as[1]) for c in chains_serial) # Multi-threaded sampling Random.seed!(1234) @@ -427,12 +429,13 @@ progress=false, chain_type=MyChain, ) + @test all(ismissing(c.as[1]) for c in chains_threads) @test all( - c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), - i in 1:N + c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), + i in 2:N ) @test all( - c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), + c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N ) @@ -447,12 +450,13 @@ progress=false, chain_type=MyChain, ) + @test all(ismissing(c.as[1]) for c in chains_distributed) @test all( - c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), - i in 1:N + c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), + i in 2:N ) @test all( - c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), + c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N ) end @@ -473,24 +477,41 @@ end @testset "Discard initial samples" begin - chain = sample(MyModel(), MySampler(), 100; sleepy=true, discard_initial=50) - @test length(chain) == 100 + # Create a chain and discard initial samples. + Random.seed!(1234) + N = 100 + discard_initial = 50 + chain = sample(MyModel(), MySampler(), N; discard_initial=discard_initial) + @test length(chain) == N @test !ismissing(chain[1].a) + + # Repeat sampling without discarding initial samples. + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + ref_chain = sample( + MyModel(), MySampler(), N + discard_initial; progress=VERSION < v"1.6" + ) + @test all(chain[i].a == ref_chain[i + discard_initial].a for i in 1:N) + @test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N) end @testset "Thin chain by a factor of `thinning`" begin # Run a thinned chain with `N` samples thinned by factor of `thinning`. - Random.seed!(1234) + Random.seed!(100) N = 100 thinning = 3 - chain = sample(MyModel(), MySampler(), N; sleepy=true, thinning=thinning) + chain = sample(MyModel(), MySampler(), N; thinning=thinning) @test length(chain) == N @test ismissing(chain[1].a) # Repeat sampling without thinning. - Random.seed!(1234) - ref_chain = sample(MyModel(), MySampler(), N * thinning; sleepy=true) - @test all(chain[i].a === ref_chain[(i - 1) * thinning + 1].a for i in 1:N) + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(100) + ref_chain = sample(MyModel(), MySampler(), N * thinning; progress=VERSION < v"1.6") + @test all(chain[i].a == ref_chain[(i - 1) * thinning + 1].a for i in 2:N) + @test all(chain[i].b == ref_chain[(i - 1) * thinning + 1].b for i in 1:N) end @testset "Sample without predetermined N" begin @@ -501,16 +522,44 @@ @test abs(bmean) <= 0.001 || length(chain) == 10_000 # Discard initial samples. - chain = sample(MyModel(), MySampler(); discard_initial=50) + Random.seed!(1234) + discard_initial = 50 + chain = sample(MyModel(), MySampler(); discard_initial=discard_initial) bmean = mean(x.b for x in chain) @test !ismissing(chain[1].a) @test abs(bmean) <= 0.001 || length(chain) == 10_000 + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + N = length(chain) + ref_chain = sample( + MyModel(), + MySampler(), + N; + discard_initial=discard_initial, + progress=VERSION < v"1.6", + ) + @test all(chain[i].a == ref_chain[i].a for i in 1:N) + @test all(chain[i].b == ref_chain[i].b for i in 1:N) + # Thin chain by a factor of `thinning`. - chain = sample(MyModel(), MySampler(); thinning=3) + Random.seed!(1234) + thinning = 3 + chain = sample(MyModel(), MySampler(); thinning=thinning) bmean = mean(x.b for x in chain) @test ismissing(chain[1].a) @test abs(bmean) <= 0.001 || length(chain) == 10_000 + + # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here. + # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258 + Random.seed!(1234) + N = length(chain) + ref_chain = sample( + MyModel(), MySampler(), N; thinning=thinning, progress=VERSION < v"1.6" + ) + @test all(chain[i].a == ref_chain[i].a for i in 2:N) + @test all(chain[i].b == ref_chain[i].b for i in 1:N) end @testset "Sample vector of `NamedTuple`s" begin diff --git a/test/stepper.jl b/test/stepper.jl index 1b570557..bc0ea8b2 100644 --- a/test/stepper.jl +++ b/test/stepper.jl @@ -29,4 +29,46 @@ @test Base.IteratorSize(iter) == Base.IsInfinite() @test Base.IteratorEltype(iter) == Base.EltypeUnknown() end + + @testset "Discard initial samples" begin + # Create a chain of `N` samples after discarding some initial samples. + Random.seed!(1234) + N = 50 + discard_initial = 10 + iter = AbstractMCMC.steps(MyModel(), MySampler(); discard_initial=discard_initial) + as = [] + bs = [] + for t in Iterators.take(iter, N) + push!(as, t.a) + push!(bs, t.b) + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample( + MyModel(), MySampler(), N; discard_initial=discard_initial, progress=false + ) + @test all(as[i] === chain[i].a for i in 1:N) + @test all(bs[i] === chain[i].b for i in 1:N) + end + + @testset "Thin chain by a factor of `thinning`" begin + # Create a thinned chain with a thinning factor of `thinning`. + Random.seed!(1234) + N = 50 + thinning = 3 + iter = AbstractMCMC.steps(MyModel(), MySampler(); thinning=thinning) + as = [] + bs = [] + for t in Iterators.take(iter, N) + push!(as, t.a) + push!(bs, t.b) + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample(MyModel(), MySampler(), N; thinning=thinning, progress=false) + @test all(as[i] === chain[i].a for i in 1:N) + @test all(bs[i] === chain[i].b for i in 1:N) + end end diff --git a/test/transducer.jl b/test/transducer.jl index f9e1a049..c534ac90 100644 --- a/test/transducer.jl +++ b/test/transducer.jl @@ -49,4 +49,50 @@ @test mean(bs) ≈ 0.0 atol = 5e-2 @test var(bs) ≈ 1 atol = 5e-2 end + + @testset "Discard initial samples" begin + # Create a chain of `N` samples after discarding some initial samples. + Random.seed!(1234) + N = 50 + discard_initial = 10 + xf = opcompose( + AbstractMCMC.Sample(MyModel(), MySampler(); discard_initial=discard_initial), + Map(x -> (x.a, x.b)), + ) + as, bs = foldl(xf, 1:N; init=([], [])) do (as, bs), (a, b) + push!(as, a) + push!(bs, b) + as, bs + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample( + MyModel(), MySampler(), N; discard_initial=discard_initial, progress=false + ) + @test all(as[i] === chain[i].a for i in 1:N) + @test all(bs[i] === chain[i].b for i in 1:N) + end + + @testset "Thin chain by a factor of `thinning`" begin + # Create a thinned chain with a thinning factor of `thinning`. + Random.seed!(1234) + N = 50 + thinning = 3 + xf = opcompose( + AbstractMCMC.Sample(MyModel(), MySampler(); thinning=thinning), + Map(x -> (x.a, x.b)), + ) + as, bs = foldl(xf, 1:N; init=([], [])) do (as, bs), (a, b) + push!(as, a) + push!(bs, b) + as, bs + end + + # Repeat sampling with `sample`. + Random.seed!(1234) + chain = sample(MyModel(), MySampler(), N; thinning=thinning, progress=false) + @test all(as[i] === chain[i].a for i in 1:N) + @test all(bs[i] === chain[i].b for i in 1:N) + end end