Skip to content

Commit

Permalink
Merge pull request #63 from TuringLang/csp/misc-fixes
Browse files Browse the repository at this point in the history
Miscellaneous fixes
  • Loading branch information
cpfiffer authored Mar 10, 2019
2 parents 400aa57 + b4ceb96 commit e83df25
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/MCMCChains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Parameters:
The `info` field can be set using `setinfo(c::Chains, n::NamedTuple)`.
"""
struct Chains{A, T, K<:NamedTuple, L<:NamedTuple} <: AbstractChains
value::AxisArray{Union{Missing,A},3}
value::AxisArray{A,3}
logevidence::T
name_map::K
info::L
Expand Down
8 changes: 5 additions & 3 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function Chains(val::AbstractArray{A,3},

# Construct the AxisArray.
axs = ntuple(i -> Axis{names[i]}(axvals[i]), 3)
arr = AxisArray(convert(Array{Union{Missing,A},3}, val), axs...)
arr = AxisArray(val, axs...)
return sort(
Chains{A, typeof(evidence), typeof(name_map_tupl), typeof(info)}(
arr, evidence, name_map_tupl, info)
Expand Down Expand Up @@ -168,8 +168,8 @@ end
Base.getindex(c::Chains, i1::T) where T<:Union{AbstractUnitRange, StepRange} = c[i1, :, :]
Base.getindex(c::Chains, i1::Integer) = c[i1:i1, :, :]
Base.getindex(c::Chains, v::Symbol) = c[[v]]
Base.getindex(c::Chains, v::String) = Array(c[:, [v], :])
Base.getindex(c::Chains, v::Vector{String}) = Array(c[:, v, :])
Base.getindex(c::Chains, v::String) = c[:, [v], :]
Base.getindex(c::Chains, v::Vector{String}) = c[:, v, :]

function Base.getindex(c::Chains, v::Vector{Symbol})
syms = _sym2index(c, v)
Expand Down Expand Up @@ -198,6 +198,8 @@ function Base.getindex(c::Chains{A, T, K, L}, i...) where {A, T, K, L}
end

Base.setindex!(c::Chains, v, i...) = setindex!(c.value, v, i...)
Base.lastindex(c::Chains) = lastindex(c.value, 1)
Base.lastindex(c::Chains, d::Integer) = lastindex(c.value, d)

"""
Base.get(c::Chains, v::Symbol; flatten=false)
Expand Down

0 comments on commit e83df25

Please sign in to comment.