Skip to content

Commit

Permalink
Include state and kwargs... to callback (#56)
Browse files Browse the repository at this point in the history
* Include `state` and `kwargs...` to `callback`

* Update README.md

* Add test for testing callbacks do the right thing.

* Update Project.toml

Co-authored-by: Cameron Pfiffer <[email protected]>
  • Loading branch information
theogf and cpfiffer authored Apr 7, 2021
1 parent 3263028 commit d961513
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "2.5.0"
version = "3.0.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ are:
- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
- `chain_type` (default: `Any`): determines the type of the returned chain
- `callback` (default: `nothing`): if `callback !== nothing`, then
`callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
`callback(rng, model, sampler, sample, state, iteration; kwargs...)` is called after every sampling step,
where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration
- `discard_initial` (default: `0`): number of initial samples that are discarded
- `thinning` (default: `1`): factor by which to thin samples.
Expand Down
4 changes: 2 additions & 2 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function mcmcsample(
end

# Run callback.
callback === nothing || callback(rng, model, sampler, sample, 1)
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
Expand Down Expand Up @@ -140,7 +140,7 @@ function mcmcsample(
sample, state = step(rng, model, sampler, state; kwargs...)

# Run callback.
callback === nothing || callback(rng, model, sampler, sample, i)
callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
Expand Down
10 changes: 10 additions & 0 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,4 +289,14 @@
@test mean(x.b for x in chain) 0 atol=0.1
@test var(x.b for x in chain) 1 atol=0.15
end

@testset "Testing callbacks" begin
function count_iterations(rng, model, sampler, sample, state, i; iter_array, kwargs...)
iter_array[i] = i
end
N = 100
it_array = zeros(N)
sample(MyModel(), MySampler(), N; callback=count_iterations, iter_array=it_array)
@test it_array == collect(1:N)
end
end

2 comments on commit d961513

@theogf
Copy link
Member Author

@theogf theogf commented on d961513 Apr 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33735

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.0.0 -m "<description of version>" d961513b0ad3dc53d414da6581b332b66c4d0ef3
git push origin v3.0.0

Please sign in to comment.