Skip to content

Commit

Permalink
Add discard_initial keyword argument (#47)
Browse files Browse the repository at this point in the history
* Add `discard_initial` keyword argument

* Update README

* Update Transducer syntax

* Bump version
  • Loading branch information
devmotion authored Aug 22, 2020
1 parent 9e77a24 commit faf1d73
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "2.0.0"
version = "2.1.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ Common keyword arguments for regular and parallel sampling (not supported by the
are:
- `progress` (default: `true`): toggles progress logging
- `chain_type` (default: `Any`): determines the type of the returned chain
- `callback` (default: `nothing`): if `callback !== nohting`, then
- `callback` (default: `nothing`): if `callback !== nothing`, then
`callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration
- `discard_initial` (default: `0`): number of initial samples that are discarded

Additionally, AbstractMCMC defines the abstract type `AbstractChains` for Markov chains and the
method `AbstractMCMC.chainscat(::AbstractChains...)` for concatenating multiple chains.
Expand Down
22 changes: 20 additions & 2 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,27 @@ function mcmcsample(
progress = true,
progressname = "Sampling",
callback = nothing,
discard_initial = 0,
chain_type::Type=Any,
kwargs...
)
# Check the number of requested samples.
N > 0 || error("the number of samples must be ≥ 1")
Ntotal = N + discard_initial

@ifwithprogresslogger progress name=progressname begin
# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)

# Discard initial samples.
for i in 1:(discard_initial - 1)
# Update the progress bar.
progress && ProgressLogging.@logprogress i/Ntotal

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing || callback(rng, model, sampler, sample, 1)

Expand All @@ -82,7 +93,7 @@ function mcmcsample(
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)

# Update the progress bar.
progress && ProgressLogging.@logprogress 1/N
progress && ProgressLogging.@logprogress (1 + discard_initial) / Ntotal

# Step through the sampler.
for i in 2:N
Expand All @@ -96,7 +107,7 @@ function mcmcsample(
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)

# Update the progress bar.
progress && ProgressLogging.@logprogress i/N
progress && ProgressLogging.@logprogress (i + discard_initial) / Ntotal
end
end

Expand Down Expand Up @@ -129,12 +140,19 @@ function mcmcsample(
progress = true,
progressname = "Convergence sampling",
callback = nothing,
discard_initial = 0,
kwargs...
)
@ifwithprogresslogger progress name=progressname begin
# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)

# Discard initial samples.
for _ in 2:discard_initial
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing || callback(rng, model, sampler, sample, 1)

Expand Down
13 changes: 13 additions & 0 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,23 @@
@test chain2 isa MyChain
end

@testset "Discard initial samples" begin
chain = sample(MyModel(), MySampler(), 100; sleepy = true, discard_initial = 50)
@test length(chain) == 100
@test !ismissing(chain[1].a)
end

@testset "Sample without predetermined N" begin
Random.seed!(1234)
chain = sample(MyModel(), MySampler())
bmean = mean(x.b for x in chain)
@test ismissing(chain[1].a)
@test abs(bmean) <= 0.001 && length(chain) < 10_000

# Discard initial samples.
chain = sample(MyModel(), MySampler(); discard_initial = 50)
bmean = mean(x.b for x in chain)
@test !ismissing(chain[1].a)
@test abs(bmean) <= 0.001 && length(chain) < 10_000
end

Expand Down
11 changes: 7 additions & 4 deletions test/transducer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Logging.with_logger(TerminalLogger()) do
xf = AbstractMCMC.Sample(MyModel(), MySampler();
sleepy = true, logger = true)
chain = collect(xf, withprogress(1:N; interval=1e-3))
chain = withprogress(1:N; interval=1e-3) |> xf |> collect
end

# test output type and size
Expand All @@ -24,16 +24,19 @@

@testset "drop" begin
xf = AbstractMCMC.Sample(MyModel(), MySampler())
chain = collect(xf |> Drop(1), 1:10)
chain = 1:10 |> xf |> Drop(1) |> collect
@test chain isa Vector{MySample{Float64,Float64}}
@test length(chain) == 9
end

# Reproduce iterator example
@testset "iterator example" begin
# filter missing values and split transitions
xf = AbstractMCMC.Sample(MyModel(), MySampler()) |>
OfType(MySample{Float64,Float64}) |> Map(x -> (x.a, x.b))
xf = opcompose(
AbstractMCMC.Sample(MyModel(), MySampler()),
OfType(MySample{Float64,Float64}),
Map(x -> (x.a, x.b)),
)
as, bs = foldl(xf, 1:999; init = (Float64[], Float64[])) do (as, bs), (a, b)
push!(as, a)
push!(bs, b)
Expand Down

2 comments on commit faf1d73

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/19958

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.1.0 -m "<description of version>" faf1d73433734f53bb24452ae273401e0969f799
git push origin v2.1.0

Please sign in to comment.