diff --git a/test/sample.jl b/test/sample.jl index 00f7ccae..b07387c7 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -292,11 +292,16 @@ @testset "Testing callbacks" begin function count_iterations(rng, model, sampler, sample, state, i; iter_array, kwargs...) - iter_array[i] = i + push!(iter_array, i) end N = 100 - it_array = zeros(N) + it_array = Float64[] sample(MyModel(), MySampler(), N; callback=count_iterations, iter_array=it_array) @test it_array == collect(1:N) + + # sampling without predetermined N + it_array = Float64[] + chain = sample(MyModel(), MySampler(); callback=count_iterations, iter_array=it_array) + @test it_array == collect(1:size(chain, 1)) end end