From 50b13703d11c9ca1088036edd452a9d8c0538ff6 Mon Sep 17 00:00:00 2001 From: Rob J Goedman Date: Wed, 27 Mar 2019 10:21:12 -0600 Subject: [PATCH] Updated to README and docs. Simplified sampling_tests.jl by using new Array feature. --- README.md | 84 +++++++++++++++++++++++++++++------ src/constructors.jl | 2 - test/df_chainsummary_tests.jl | 36 --------------- test/sampling_tests.jl | 8 ++-- 4 files changed, 75 insertions(+), 55 deletions(-) delete mode 100644 test/df_chainsummary_tests.jl diff --git a/README.md b/README.md index 5df0c078..dd95338d 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ chn = Chains(val, Dict(:internals => ["d", "e"])) ``` -Or through the `set_section` function, which returns a new `Chains` object. `Chains` objects cannot be modified in place due to section map immutability: +Or through the `set_section` function, which returns a new `Chains` object (as `Chains` objects cannot be modified in place due to section map immutability): ```julia chn2 = set_section(chn, Dict(:internals => ["d", "e"])) @@ -152,18 +152,6 @@ You can access each of the `P[. . .]` variables by indexing, using `x.P[1]`, `x. Note that `x.P` is a tuple which has to be indexed by the relevant index, while `x.D` is just a vector. -### Saving and Loading Chains - -Chains objects can be serialized and deserialized using `read` and `write`. - -```julia -# Save a chain. -write("chain-file.jls", chn) - -# Read a chain. -chn2 = read("chain-file.jls", Chains) -``` - ### Convergence Diagnostics functions #### Discrete Diagnostic Options for method are `[:weiss, :hangartner, :DARBOOT, MCBOOT, :billinsgley, :billingsleyBOOT]` @@ -221,5 +209,75 @@ autocorplot(c::AbstractChains) corner(c::AbstractChains, [:A, :B]) ``` +### Saving and Loading Chains + +Chains objects can be serialized and deserialized using `read` and `write`. + +```julia +# Save a chain. +write("chain-file.jls", chn) + +# Read a chain. 
+chn2 = read("chain-file.jls", Chains) +``` + +### Exporting Chains + +A few utility export functions have been provided to convert `Chains` objects to either an Array or a DataFrame: + +```julia +# Several examples of creating an Array object: +Array(chns) +Array(chns[:s]) +Array(chns, [:parameters]) +Array(chns, [:parameters, :internals]) + +# By default chains are appended. This can be disabled +# using the append_chains keyword argument: +Array(chns, append_chains=false) + +# This will return an `Array{Array, 1}` object containing +# an Array for each chain. + +# A final option is: +Array(chns, remove_missing_union=false) + +# This will not convert the Array columns from a +# `Union{Missing, Real}` to a `Vector{Real}`. +``` + +Similarly, for DataFrames: + +```julia +DataFrame(chns) +DataFrame(chns[:s]) +DataFrame(chns, [:parameters]) +DataFrame(chns, [:parameters, :internals]) +DataFrame(chns, append_chains=false) +DataFrame(chns, remove_missing_union=false) +``` + +See also ?MCMCChains.DataFrame and ?MCMCChains.Array for more help. + +### Sampling Chains + +MCMCChains overloads several `sample()` methods as defined in StatsBase: + +```julia +# Sampling `n` samples from the chain `a`. Optionally +# weighting the samples using `wv`. +sample([rng], a, [wv::AbstractWeights], n::Integer) + +# E.g. creating 100000 weighted samples: +c = kde(Array(chn[:s])) +chn_weighted_sample = sample(c.x, Weights(c.density), 100000) + +# As above, but supports replacing and ordering. +sample([rng], a, [wv::AbstractWeights], n::Integer; replace=true, + ordered=false) +``` + +See also ?MCMCChains.sample for additional help. + ## License Notice Note that this package heavily uses and adapts code from the Mamba.jl package licensed under MIT License, see License.md. 
diff --git a/src/constructors.jl b/src/constructors.jl index 1c062f4d..2246a741 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -120,8 +120,6 @@ function Array(chn::MCMCChains.AbstractChains, b end -Base.convert(::Type{Array}, chn::MCMCChains.Chains) = convert(Array, chn.value) - """ # DataFrame diff --git a/test/df_chainsummary_tests.jl b/test/df_chainsummary_tests.jl deleted file mode 100644 index 9ca95d8f..00000000 --- a/test/df_chainsummary_tests.jl +++ /dev/null @@ -1,36 +0,0 @@ -using Turing, MCMCChains, StatsBase - -#@testset "Dataframe summary" begin - - @model gdemo(x) = begin - m ~ Normal(1, 0.01) - s ~ Normal(5, 0.01) - end - - model = gdemo([1.5, 2.0]) - sampler = HMC(1000, 0.01, 5) - - chns = [sample(model, sampler, save_state=true) for i in 1:4] - chns = chainscat(chns...) - - d, p, c = size(chns.value.data) - - - describe(chns) - - display(StatsBase.summarystats(chns).summaries[1]) - pars = StatsBase.summarystats(chns).summaries[1].value - println() - - display(StatsBase.summarystats(chns, section=:internals).summaries[1]) - internals = StatsBase.summarystats(chns, section=:internals).summaries[1].value - - convert(::Type{Array}, chn::Chains) = chn.value.data - - chns.name_map - - arry = convert(Array, chns) - size(arry) - arry[1:5, :, 1:2] - - #end \ No newline at end of file diff --git a/test/sampling_tests.jl b/test/sampling_tests.jl index b9a17d9a..cd749706 100644 --- a/test/sampling_tests.jl +++ b/test/sampling_tests.jl @@ -1,6 +1,6 @@ using Turing, MCMCChains, KernelDensity, StatsBase, Test, Statistics -@testset "sampling api" begin +#@testset "sampling api" begin @model gdemo(x) = begin m ~ Normal(1, 0.01) @@ -14,9 +14,9 @@ using Turing, MCMCChains, KernelDensity, StatsBase, Test, Statistics chn_sample = sample(chn, 5) @test range(chn_sample) == 1:1:5 - c = kde(reshape(convert(Array{Float64}, chn[:s].value), 500)) + c = kde(Array(chn[:s])) chn_weighted_sample = sample(c.x, Weights(c.density), 100000) - @test 
mean(convert(Array{Float64}, chn[:s].value)) ≈ 5.0 atol=0.1 + @test mean(Array(chn[:s])) ≈ 5.0 atol=0.1 -end + #end