diff --git a/Project.toml b/Project.toml index 477ccc9d..e536a612 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.0.0" +version = "3.0.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/sample.jl b/src/sample.jl index 563d4652..3bf55083 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -197,7 +197,7 @@ function mcmcsample( end # Run callback. - callback === nothing || callback(rng, model, sampler, sample, 1) + callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) @@ -217,7 +217,7 @@ function mcmcsample( sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, i) + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler; kwargs...) diff --git a/test/sample.jl b/test/sample.jl index 00f7ccae..b07387c7 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -292,11 +292,16 @@ @testset "Testing callbacks" begin function count_iterations(rng, model, sampler, sample, state, i; iter_array, kwargs...) - iter_array[i] = i + push!(iter_array, i) end N = 100 - it_array = zeros(N) + it_array = Float64[] sample(MyModel(), MySampler(), N; callback=count_iterations, iter_array=it_array) @test it_array == collect(1:N) + + # sampling without predetermined N + it_array = Float64[] + chain = sample(MyModel(), MySampler(); callback=count_iterations, iter_array=it_array) + @test it_array == collect(1:size(chain, 1)) end end