diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index b7d0ff94..4a43de1e 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -37,7 +37,7 @@ Parameters: The `info` field can be set using `setinfo(c::Chains, n::NamedTuple)`. """ struct Chains{A, T, K<:NamedTuple, L<:NamedTuple} <: AbstractChains - value::AxisArray{Union{Missing,A},3} + value::AxisArray{A,3} logevidence::T name_map::K info::L diff --git a/src/chains.jl b/src/chains.jl index f4b12634..b02fb010 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -90,7 +90,7 @@ function Chains(val::AbstractArray{A,3}, # Construct the AxisArray. axs = ntuple(i -> Axis{names[i]}(axvals[i]), 3) - arr = AxisArray(convert(Array{Union{Missing,A},3}, val), axs...) + arr = AxisArray(val, axs...) return sort( Chains{A, typeof(evidence), typeof(name_map_tupl), typeof(info)}( arr, evidence, name_map_tupl, info) @@ -168,8 +168,8 @@ end Base.getindex(c::Chains, i1::T) where T<:Union{AbstractUnitRange, StepRange} = c[i1, :, :] Base.getindex(c::Chains, i1::Integer) = c[i1:i1, :, :] Base.getindex(c::Chains, v::Symbol) = c[[v]] -Base.getindex(c::Chains, v::String) = Array(c[:, [v], :]) -Base.getindex(c::Chains, v::Vector{String}) = Array(c[:, v, :]) +Base.getindex(c::Chains, v::String) = c[:, [v], :] +Base.getindex(c::Chains, v::Vector{String}) = c[:, v, :] function Base.getindex(c::Chains, v::Vector{Symbol}) syms = _sym2index(c, v) @@ -198,6 +198,8 @@ function Base.getindex(c::Chains{A, T, K, L}, i...) where {A, T, K, L} end Base.setindex!(c::Chains, v, i...) = setindex!(c.value, v, i...) +Base.lastindex(c::Chains) = lastindex(c.value, 1) +Base.lastindex(c::Chains, d::Integer) = lastindex(c.value, d) """ Base.get(c::Chains, v::Symbol; flatten=false)