Skip to content

Commit

Permalink
Make MCMCChains better (#181)
Browse files Browse the repository at this point in the history
* Make MCMCChains better

* Add modelstats.jl back.

* Address comments, fix test

* Fix bug, remove deprecated keyword.

* Change tolerance for diff test

* Reduce test accuracy

* Set seed for missing tests.

* Move DataFrames to [extra]

* Address comments
  • Loading branch information
cpfiffer authored Feb 24, 2020
1 parent d626082 commit a075207
Show file tree
Hide file tree
Showing 19 changed files with 374 additions and 344 deletions.
11 changes: 7 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = " Chain types and utility functions for MCMC simulations."
version = "2.0.0"
version = "3.0.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -21,9 +22,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[compat]
AbstractMCMC = "0.4"
AxisArrays = "^0.3, 0.4"
DataFrames = "^0.19, 0.20"
Distributions = "^0.21, 0.22"
Formatting = "0.4"
RecipesBase = "^0.7, 0.8"
Requires = "0.5, 1.0"
SpecialFunctions = "^0.8, 0.9, 0.10"
StatsBase = "^0.32"
julia = "^1"
Expand All @@ -32,6 +34,7 @@ julia = "^1"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"

[targets]
test = ["KernelDensity", "StatsPlots", "Test"]
test = ["KernelDensity", "StatsPlots", "Test", "DataFrames"]
21 changes: 12 additions & 9 deletions src/MCMCChains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,27 @@ using AxisArrays
const axes = Base.axes

import AbstractMCMC
using AbstractMCMC: chainscat
import DataFrames
using DataFrames: eachcol, DataFrame
import AbstractMCMC: chainscat
using Distributions
using RecipesBase
using SpecialFunctions
using StatsBase: autocov, counts, sem, AbstractWeights
import StatsBase: autocor, describe, quantile, sample, summarystats
using Formatting
import StatsBase: autocov, counts, sem, AbstractWeights,
autocor, describe, quantile, sample, summarystats, cov
using Requires

using LinearAlgebra: diag
import Serialization: serialize, deserialize
import Random
import Statistics: std, cor, mean
import Statistics: std, cor, mean, var

export Chains, chains, chainscat
export set_section, get_params, sections, sort_sections, setinfo, set_names
export mean
export autocor, describe, sample, summarystats, AbstractWeights
export autocor, describe, sample, summarystats, AbstractWeights, mean, quantile
export ChainDataFrame, DataFrame
export summarize

# export diagnostics functions
# Export diagnostics functions
export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag
export hpd, ess

Expand Down Expand Up @@ -63,4 +62,8 @@ include("stats.jl")
include("modelstats.jl")
include("plot.jl")

# Module initialisation hook.
# Uses the `@require` macro from Requires (imported at the top of this module) to tie
# the inclusion of "dataframes-compat.jl" to the DataFrames package; the UUID below is
# the same one listed for DataFrames in Project.toml.
function __init__()
@require DataFrames="a93c6f00-e57d-5684-b7b6-d8193f3e46c0" include("dataframes-compat.jl")
end

end # module
42 changes: 2 additions & 40 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,6 @@ function Chains(
# Make the name_map NamedTuple.
name_map_tupl = _dict2namedtuple(name_map)

# Ensure that we have a hashedsummary key in info.
if !in(:hashedsummary, keys(info))
empty_df_vec = [ChainDataFrame("", DataFrame())]
s = (hash(0), empty_df_vec)
info = merge(info, (hashedsummary = Ref(s),))
end

# Construct the AxisArray.
arr = AxisArray(val;
iter = range(start, step=thin, length=size(val, 1)),
Expand Down Expand Up @@ -292,21 +285,8 @@ function Base.show(io::IO, c::Chains)
print(io, "Object of type Chains, with data of type $(summary(c.value.data))\n\n")
println(io, header(c))

# Grab the value hash.
h = hash(c)

if :hashedsummary in keys(c.info)
s = c.info.hashedsummary.x
if s[1] == h
show(io, s[2])
else
new_summary = describe(c)
c.info.hashedsummary.x = (h, new_summary)
show(io, new_summary)
end
else
show(io, describe(c, suppress_header=true))
end
# Show summary stats.
show(io, describe(c))
end

Base.keys(c::Chains) = names(c)
Expand All @@ -321,24 +301,6 @@ Base.convert(::Type{Array}, chn::Chains) = convert(Array, chn.value)

#################### Auxiliary Functions ####################

"""
    hash(c::Chains, h::UInt)

Hash a `Chains` object from its `value`, `info`, `name_map` and `logevidence` fields,
mixing in the seed `h`.

The seeded two-argument method is the form `Base` expects new types to implement, so
it is the one defined here. Calling `hash(c)` still works: the generic one-argument
fallback in `Base` forwards to this method with a zero seed.
"""
function Base.hash(c::Chains, h::UInt)
    # Chain the seed through every field so each one contributes to the result.
    h = hash(c.logevidence, h)
    h = hash(c.name_map, h)
    h = hash(c.info, h)
    return hash(c.value, h)
end

"""
    combine(c::Chains)

Pool all chains of `c` into a single `Float64` matrix.

`c.value` is read as an `iterations × parameters × chains` array and the result has
`iterations * chains` rows and one column per parameter. Rows are ordered by
iteration first and chain second: all chains of iteration 1, then all chains of
iteration 2, and so on.
"""
function combine(c::Chains)
    niter, nparam, nchain = size(c.value)
    pooled = Array{Float64}(undef, niter * nchain, nparam)
    for param in 1:nparam, iter in 1:niter, chain in 1:nchain
        # Row position of (iter, chain) when chains are interleaved per iteration.
        row = (iter - 1) * nchain + chain
        pooled[row, param] = c.value[iter, param, chain]
    end
    return pooled
end

"""
range(c::Chains)
Expand Down
179 changes: 15 additions & 164 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,178 +62,29 @@ inclusion, a dimension is dropped in both cases, as is e.g. required by cde(), e
function Base.Array(chn::Chains,
sections::Union{Symbol, Vector{Symbol}}=Symbol[:parameters];
append_chains=true,
remove_missing_union=true,
showall=false,
sorted=false
)
sections = _clean_sections(chn, sections)
sections = sections isa AbstractArray ? sections : [sections]
sections = showall ? [] : sections
section_list = length(sections) == 0 ?
sort_sections(chn) :
sections
sections = showall ? keys(chn.name_map) : sections
chn = Chains(chn, sections)

# If we actually have missing values, we can't remove
# Union{Missing}.
remove_missing_union = remove_missing_union ?
all(x -> !ismissing(x), chn.value) :
remove_missing_union

d, p, c = size(chn.value.data)

local b
if append_chains
first_parameter = true
for section in section_list
for par in chn.name_map[section]
x = get(chn, Symbol(par))
d, c = size(x[Symbol(par)])
if first_parameter
if remove_missing_union
b = reshape(convert(Array{Float64}, x[Symbol(par)]), d*c)[:, 1]
else
b = reshape(x[Symbol(par)], d*c)[:, 1]
end
p == 1 && (b = reshape(b, size(b, 1)))
first_parameter = false
else
if remove_missing_union
b = hcat(b, reshape(convert(Array{Float64}, x[Symbol(par)]), d*c)[:, 1])
else
b = hcat(b, reshape(x[Symbol(par)], d*c)[:, 1])
end
end
end
end
arr = if append_chains
mapreduce(i -> chn.value.data[:,:,i], vcat, 1:size(chn, 3))
else
b=Vector(undef, c)
for i in 1:c
first_parameter = true
for section in section_list
for par in chn.name_map[section]
x = get(chn, Symbol(par))
d, c = size(x[Symbol(par)])
if first_parameter
if remove_missing_union
b[i] = convert(Array{Real}, x[Symbol(par)][:, i])
else
b[i] = x[Symbol(par)][:, i]
end
p == 1 && (b[i] = reshape(b[i], size(b[i], 1)))
first_parameter = false
else
if remove_missing_union
b[i] = hcat(b[i], convert(Array{Real}, x[Symbol(par)][:, i]))
else
b[i] = hcat(b[i], x[Symbol(par)][:, i])
end
end
end
end
end
map(i -> chn.value.data[:,:,i], 1:size(chn, 3))
end
return b
end

"""
# DataFrame
DataFrame constructor from a Chains object.
Returns either a DataFrame or an Array{DataFrame}
### Method
```julia
DataFrame(
chn::Chains,
sections::Union{Symbol, Vector{Symbol}};
append_chains::Bool,
remove_missing_union::Bool
)
```
### Required arguments
```julia
* `chn` : Chains object to convert to a DataFrame
```
### Optional arguments
```julia
* `sections = Symbol[:parameters]` : Sections from the Chains object to be included
* `append_chains = true` : Append chains into a single column
* `remove_missing_union = true` : Remove Union{Missing, Real}
```
### Examples
```julia
* `DataFrame(chns)` : DataFrame with chain values appended
* `DataFrame(chns[:par])` : DataFrame with single parameter (chain values are appended)
* `DataFrame(chns, [:parameters])` : DataFrame with only the :parameters section
* `DataFrame(chns, [:parameters, :internals])` : DataFrame includes both sections
* `DataFrame(chns, append_chains=false)` : Array of DataFrames, each chain in its own array
* `DataFrame(chns, remove_missing_union=false)` : No conversion to remove missing values
```
"""
function DataFrames.DataFrame(chn::Chains,
sections::Union{Symbol, Vector{Symbol}}=Symbol[:parameters];
append_chains=true,
remove_missing_union=true,
sorted=false,
showall=false)
sections = _clean_sections(chn, sections)
sections = sections isa AbstractArray ? sections : [sections]
sections = showall ? [] : sections
section_list = length(sections) == 0 ? sort_sections(chn) : sections

# If we actually have missing values, we can't remove
# Union{Missing}.
remove_missing_union = remove_missing_union ?
all(!ismissing, chn.value) :
remove_missing_union

d, p, c = size(chn.value.data)
arr = if append_chains
reshape(arr, (size(arr, 1), size(arr, 2)))
else
arr
end

local b
if append_chains
b = DataFrame()
for section in section_list
names = sorted ?
sort(chn.name_map[section],
by=x->string(x), lt = natural) :
chn.name_map[section]
for par in names
x = get(chn, Symbol(par))
d, c = size(x[Symbol(par)])
if remove_missing_union
b = hcat(b, DataFrame(Symbol(par) => reshape(convert(Array{Float64},
x[Symbol(par)]), d*c)[:, 1]))
else
b = hcat(b, DataFrame(Symbol(par) => reshape(x[Symbol(par)], d*c)[:, 1]))
end
end
end
if size(arr, 2) == 1
return map(identity, arr[:,1])
else
b = Vector{DataFrame}(undef, c)
for i in 1:c
b[i] = DataFrame()
for section in section_list
names = sorted ?
sort(chn.name_map[section],
by=x->string(x), lt = natural) :
chn.name_map[section]
for par in names
x = get(chn[:,:,i], Symbol(par))
d, c = size(x[Symbol(par)])
if remove_missing_union
b[i] = hcat(b[i], DataFrame(Symbol(par) => convert(Array{Float64},
x[Symbol(par)])[:, 1]))
else
b[i] = hcat(b[i], DataFrame(Symbol(par) => x[Symbol(par)][:,1]))
end
end
end
end
return map(identity, arr)
end
return b
end


Loading

0 comments on commit a075207

Please sign in to comment.