diff --git a/Project.toml b/Project.toml index 2efb10e4..29510045 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "4.12.0" +version = "4.13.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/docs/Project.toml b/docs/Project.toml index 8991c191..d1aa2969 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,7 +3,7 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004" -MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" @@ -12,6 +12,7 @@ CategoricalArrays = "0.8, 0.9, 0.10" DataFrames = "0.22, 1" Documenter = "0.26, 0.27" Gadfly = "1.3" -MLJModels = "0.14" +MLJBase = "0.18" MLJXGBoostInterface = "0.1" StatsPlots = "0.14" +julia = "1.3" diff --git a/src/chains.jl b/src/chains.jl index d216a04e..0db00f21 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -31,9 +31,22 @@ function Chains( name_map = (parameters = parameter_names,); start::Int = 1, thin::Int = 1, + iterations::AbstractVector{Int} = range(start; step=thin, length=size(val, 1)), evidence = missing, info::NamedTuple = NamedTuple() ) + # Check that iteration numbers are reasonable + if length(iterations) != size(val, 1) + error("length of `iterations` (", length(iterations), + ") is not equal to the number of iterations (", size(val, 1), ")") + end + if !isempty(iterations) && first(iterations) < 1 + error("iteration numbers must be positive integers") + end + if !isstrictlyincreasing(iterations) + error("iteration numbers must be strictly 
increasing") + end + # Make sure that we have a `:parameters` index and # Copying can avoid state mutation. _name_map = initnamemap(name_map) @@ -58,7 +71,7 @@ function Chains( # Construct the AxisArray. arr = AxisArray(val; - iter = range(start, step=thin, length=size(val, 1)), + iter = iterations, var = parameter_names, chain = 1:size(val, 3)) @@ -444,17 +457,21 @@ Return the range of iteration indices of the `chains`. Base.range(chains::Chains) = chains.value[Axis{:iter}].val """ - setrange(chains::Chains, range) + setrange(chains::Chains, range::AbstractVector{Int}) Generate a new chain from `chains` with iterations indexed by `range`. The new chain and `chains` share the same data in memory. """ -function setrange(chains::Chains, range::AbstractRange{<:Integer}) +function setrange(chains::Chains, range::AbstractVector{Int}) if length(chains) != length(range) error("length of `range` (", length(range), ") is not equal to the number of iterations (", length(chains), ")") end + if !isempty(range) && first(range) < 1 + error("iteration numbers must be positive integers") + end + isstrictlyincreasing(range) || error("iteration numbers must be strictly increasing") value = AxisArray(chains.value.data; iter = range, var = names(chains), chain = MCMCChains.chains(chains)) @@ -574,8 +591,7 @@ function header(c::Chains; section=missing) # Return header. return string( ismissing(c.logevidence) ? "" : "Log evidence = $(c.logevidence)\n", - "Iterations = $(first(c)):$(last(c))\n", - "Thinning interval = $(step(c))\n", + "Iterations = $(range(c))\n", "Number of chains = $(size(c, 3))\n", "Samples per chain = $(length(range(c)))\n", ismissing(wall) ? "" : "Wall duration = $(round(wall, digits=2)) seconds\n", @@ -725,8 +741,11 @@ _cat(dim::Int, cs::Chains...) = _cat(Val(dim), cs...) function _cat(::Val{1}, c1::Chains, args::Chains...) 
# check inputs - thin = step(c1) - all(c -> step(c) == thin, args) || throw(ArgumentError("chain thinning differs")) + lastiter = last(c1) + for c in args + first(c) > lastiter || throw(ArgumentError("iterations have to be sorted")) + lastiter = last(c) + end nms = names(c1) all(c -> names(c) == nms, args) || throw(ArgumentError("chain names differ")) chns = chains(c1) @@ -735,7 +754,7 @@ function _cat(::Val{1}, c1::Chains, args::Chains...) # concatenate all chains data = mapreduce(c -> c.value.data, vcat, args; init = c1.value.data) value = AxisArray(data; - iter = range(first(c1); length = size(data, 1), step = thin), + iter = mapreduce(range, vcat, args; init=range(c1)), var = nms, chain = chns) diff --git a/src/fileio.jl b/src/fileio.jl index e294df21..f36ebee5 100644 --- a/src/fileio.jl +++ b/src/fileio.jl @@ -22,5 +22,5 @@ function readcoda(output::AbstractString, index::AbstractString) value[:, i] = out[inds, 2] end - Chains(value, start=first(window), thin=step(window), names=names) + Chains(value; iterations=window, names=names) end diff --git a/src/rstar.jl b/src/rstar.jl index 3f467871..c0c921b3 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -18,12 +18,11 @@ verbosity level. 
# Example
 
 ```jldoctest rstar; output = false, filter = r".*"s
-using MLJModels
+using MLJBase, MLJXGBoostInterface
 
-XGBoost = @load XGBoostClassifier verbosity=0
 chn = Chains(fill(4, 100, 2, 3))
 
-Rs = rstar(XGBoost(), chn; iterations=20)
+Rs = rstar(XGBoostClassifier(), chn; iterations=20)
 R = round(mean(Rs); digits=0)
 
 # output
diff --git a/src/utils.jl b/src/utils.jl
index 5c54f242..7ffad8bc 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -196,3 +196,30 @@ function concretize(x::Chains)
         return Chains(concretize(value), x.logevidence, x.name_map, x.info)
     end
 end
+
+"""
+    isstrictlyincreasing(x::AbstractVector{<:Integer})
+
+Return `true` if every element of `x` is greater than its predecessor.
+
+Empty and one-element vectors are strictly increasing.
+"""
+function isstrictlyincreasing(x::AbstractVector{<:Integer})
+    return isempty(x) || _isstrictlyincreasing_nonempty(x)
+end
+
+# A non-empty range is increasing iff its step is positive; a one-element range has
+# no successor to compare, so its step is irrelevant (e.g. `5:-1:5`).
+function _isstrictlyincreasing_nonempty(x::AbstractRange{<:Integer})
+    return length(x) == 1 || step(x) > 0
+end
+
+# Generic fallback: compare every element with its predecessor.
+function _isstrictlyincreasing_nonempty(x::AbstractVector{<:Integer})
+    prev = first(x)
+    for cur in Iterators.drop(x, 1)
+        cur > prev || return false
+        prev = cur
+    end
+    return true
+end
diff --git a/test/concatenation_tests.jl b/test/concatenation_tests.jl
index c34f45f7..513f4e3b 100644
--- a/test/concatenation_tests.jl
+++ b/test/concatenation_tests.jl
@@ -46,58 +46,60 @@ end
     chn = Chains(rand(10, 5, 2), ["a", "b", "c", "d", "e"], Dict(:internal => ["d", "e"]))
     chn1 = Chains(rand(5, 5, 2), ["a", "b", "c", "d", "e"], Dict(:internal => ["a", "b"]))
 
-    # incorrect thinning
-    @test_throws ArgumentError vcat(chn, Chains(rand(2, 5, 2); thin = 2))
+    # incorrect iterations
+    @test_throws ArgumentError vcat(chn, Chains(rand(2, 5, 2)))
 
     # incorrect names
-    @test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 2), ["a", "b", "c", "d", "f"]))
+    @test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 2), ["a", "b", "c", "d", "f"]; start=11))
 
     # incorrect number of chains
-    @test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 3), ["a", "b", "c", "d", "e"]))
+    @test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 3), ["a", "b", "c", "d", "e"]; start=11))
 
     # concate the same chain
-    chn2 = vcat(chn, chn)
+    chn_shifted = setrange(chn, 11:20)
+    chn2
= vcat(chn, chn_shifted) @test chn2.value.data == vcat(chn.value.data, chn.value.data) @test size(chn2) == (20, 5, 2) @test names(chn2) == names(chn) @test range(chn2) == 1:20 @test chn2.name_map == (parameters = [:a, :b, :c], internal = [:d, :e]) - - chn2a = cat(chn, chn) + + chn2a = cat(chn, chn_shifted) @test chn2a.value == chn2.value @test chn2a.name_map == chn2.name_map @test chn2a.info == chn2.info - chn2b = cat(chn, chn; dims = Val(1)) + chn2b = cat(chn, chn_shifted; dims = Val(1)) @test chn2b.value == chn2.value @test chn2b.name_map == chn2.name_map @test chn2b.info == chn2.info - chn2c = cat(chn, chn; dims = 1) + chn2c = cat(chn, chn_shifted; dims = 1) @test chn2c.value == chn2.value @test chn2c.name_map == chn2.name_map @test chn2c.info == chn2.info # concatenate a different chain - chn3 = vcat(chn, chn1) + chn1_shifted = setrange(chn1, 11:15) + chn3 = vcat(chn, chn1_shifted) @test chn3.value.data == vcat(chn.value.data, chn1.value.data) @test size(chn3) == (15, 5, 2) @test names(chn3) == names(chn) @test range(chn3) == 1:15 # just take the name map of first argument @test chn3.name_map == (parameters = [:a, :b, :c], internal = [:d, :e]) - - chn3a = cat(chn, chn1) + + chn3a = cat(chn, chn1_shifted) @test chn3a.value == chn3.value @test chn3a.name_map == chn3.name_map @test chn3a.info == chn3.info - chn3b = cat(chn, chn1; dims = Val(1)) + chn3b = cat(chn, chn1_shifted; dims = Val(1)) @test chn3b.value == chn3.value @test chn3b.name_map == chn3.name_map @test chn3b.info == chn3.info - chn3c = cat(chn, chn1; dims = 1) + chn3c = cat(chn, chn1_shifted; dims = 1) @test chn3c.value == chn3.value @test chn3c.name_map == chn3.name_map @test chn3c.info == chn3.info diff --git a/test/diagnostic_tests.jl b/test/diagnostic_tests.jl index 5b51418e..4f995ddc 100644 --- a/test/diagnostic_tests.jl +++ b/test/diagnostic_tests.jl @@ -15,6 +15,12 @@ val = hcat(val, rand(1:2, niter, 1, nchains)) # construct a Chains object chn = Chains(val, start = 1, thin = 2) +@test_throws 
ErrorException Chains(val; start=0, thin=2) +@test_throws ErrorException Chains(val; start=niter, thin=-1) +@test_throws ErrorException Chains(val; iterations=1:(niter - 1)) +@test_throws ErrorException Chains(val; iterations=range(0; step=2, length=niter)) +@test_throws ErrorException Chains(val; iterations=niter:-1:1) +@test_throws ErrorException Chains(val; iterations=ones(Int, niter)) # Chains object for discretediag val_disc = rand(Int16, 200, nparams, nchains) @@ -29,18 +35,26 @@ chn_disc = Chains(val_disc, start = 1, thin = 2) @test keys(chn) == names(chn) == [:param_1, :param_2, :param_3, :param_4] @test range(chn) == range(1; step = 2, length = niter) + @test range(chn) == range(Chains(val; iterations=range(chn))) + @test range(chn) == range(Chains(val; iterations=collect(range(chn)))) @test_throws ErrorException setrange(chn, 1:10) + @test_throws ErrorException setrange(chn, 0:(niter - 1)) + @test_throws ErrorException setrange(chn, niter:-1:1) + @test_throws ErrorException setrange(chn, ones(Int, niter)) @test_throws MethodError setrange(chn, float.(range(chn))) - chn2 = setrange(chn, range(1; step = 10, length = niter)) - @test range(chn2) == range(1; step = 10, length = niter) - @test names(chn2) === names(chn) - @test chains(chn2) === chains(chn) - @test chn2.value.data === chn.value.data - @test chn2.logevidence === chn.logevidence - @test chn2.name_map === chn.name_map - @test chn2.info == chn.info + chn2a = setrange(chn, range(1; step = 10, length = niter)) + chn2b = setrange(chn, collect(range(1; step = 10, length = niter))) + for chn2 in (chn2a, chn2b) + @test range(chn2) == range(1; step = 10, length = niter) + @test names(chn2) === names(chn) + @test chains(chn2) === chains(chn) + @test chn2.value.data === chn.value.data + @test chn2.logevidence === chn.logevidence + @test chn2.name_map === chn.name_map + @test chn2.info == chn.info + end chn3 = resetrange(chn) @test range(chn3) == 1:niter diff --git a/test/rstar_tests.jl b/test/rstar_tests.jl 
index fded4784..cc4f7dd4 100644 --- a/test/rstar_tests.jl +++ b/test/rstar_tests.jl @@ -1,5 +1,6 @@ using MCMCChains -using MLJModels +using MLJBase +using MLJXGBoostInterface using Test N = 1000 @@ -8,8 +9,7 @@ colnames = ["a", "b", "c", "d", "e", "f", "g", "h"] internal_colnames = ["c", "d", "e", "f", "g", "h"] chn = Chains(val, colnames, Dict(:internals => internal_colnames)) -XGBoost = @load XGBoostClassifier -classif = XGBoost() +classif = XGBoostClassifier() @testset "R star test" begin # Compute R* statistic for a mixed chain. diff --git a/test/runtests.jl b/test/runtests.jl index dd88c889..568edc95 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,7 +22,7 @@ Random.seed!(0) if VERSION >= v"1.3" && Sys.WORD_SIZE == 64 # run tests related to rstar statistic println("Rstar") - Pkg.add("MLJModels") + Pkg.add("MLJBase") Pkg.add("MLJXGBoostInterface") @time include("rstar_tests.jl")