Skip to content

Commit

Permalink
Merge pull request #79 from TuringLang/dfchainsummary
Browse files Browse the repository at this point in the history
Use DataFrames for ChainSummaries
  • Loading branch information
cpfiffer authored Apr 16, 2019
2 parents 64efd3c + d3d9ec2 commit 6b693f5
Show file tree
Hide file tree
Showing 20 changed files with 886 additions and 727 deletions.
10 changes: 6 additions & 4 deletions src/MCMCChains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import StatsBase: autocor, autocov, countmap, counts, describe, predict,
quantile, sample, sem, summarystats, sample, AbstractWeights
import LinearAlgebra: diag
import Serialization: serialize, deserialize
import Base: sort, range, names, get, hash, convert
import Base: sort, range, names, get, hash, convert, show, display
import Statistics: cor
import Core.Array
import DataFrames: DataFrame
import DataFrames: DataFrame, names

using RecipesBase
import RecipesBase: plot
Expand All @@ -24,10 +24,12 @@ export Chains, getindex, setindex!, chains, setinfo, chainscat
export describe, set_section, get_params, sections
export sample, AbstractWeights
export Array, DataFrame, sort_sections, convert
export hpd
export summarize, summarystats, ChainDataFrame
export hpd, ess

# export diagnostics functions
export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag
export autocor

abstract type AbstractChains end

Expand All @@ -53,8 +55,8 @@ end
include("utils.jl")

include("chains.jl")
include("chainsummary.jl")
include("constructors.jl")
include("summarize.jl")
include("discretediag.jl")
include("fileio.jl")
include("gelmandiag.jl")
Expand Down
35 changes: 30 additions & 5 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function Chains(val::AbstractArray{A,3},

# Ensure that we have a hashedsummary key in info.
if !in(:hashedsummary, keys(info))
s = (hash(0), ChainSummaries("", []))
s = (hash(0), ChainDataFrame("", DataFrame()))
info = merge(info, (hashedsummary = Ref(s),))
end

Expand All @@ -105,9 +105,9 @@ function Chains(c::Chains{A, T, K, L}, section::Union{Vector, Any};
# If we received an empty list, return the original chain.
if isempty(section)
if sorted
return sort(new_chn)
return sort(c)
else
return new_chn
return c
end
end

Expand Down Expand Up @@ -314,7 +314,7 @@ function Base.show(io::IO, c::Chains)
if s[1] == h
show(io, s[2])
else
new_summary = summarystats(c, suppress_header=true)
new_summary = summarystats(c)
c.info.hashedsummary.x = (h, new_summary)
show(io, new_summary)
end
Expand Down Expand Up @@ -376,14 +376,34 @@ function chains(c::AbstractChains)
end

"""
names(c::AbstractChains)
names(c::AbstractChains, sections)
Return the parameter names in a `Chains` object.
"""
function names(c::AbstractChains)
return c.value[Axis{:var}].val
end

"""
names(c::AbstractChains, sections::Union{Symbol, Vector{Symbol}})
Return the parameter names in a `Chains` object, given an array of sections.
"""
function names(c::AbstractChains,
sections::Union{Symbol, Vector{Symbol}})
# Check that sections is an array.
sections = typeof(sections) <: AbstractArray ?
sections :
[sections]

nms = []

for i in sections
push!(nms, c.name_map[i]...)
end
return nms
end

"""
get_sections(c::AbstractChains, sections::Vector = [])
Expand Down Expand Up @@ -596,6 +616,11 @@ function _use_showall(c::AbstractChains, section::Symbol)
return false
end

function _clean_sections(c::AbstractChains, sections::Vector{Symbol})
ks = collect(keys(c.name_map))
return ks sections
end

#################### Concatenation ####################

function Base.cat(c1::AbstractChains, args::AbstractChains...; dims::Integer = 3)
Expand Down
127 changes: 0 additions & 127 deletions src/chainsummary.jl

This file was deleted.

Loading

0 comments on commit 6b693f5

Please sign in to comment.