diff --git a/Project.toml b/Project.toml index 9e2656e9..e1660b54 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "4.7.1" +version = "4.7.2" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/tables.jl b/src/tables.jl index 65cfbd3f..d5ad2e67 100644 --- a/src/tables.jl +++ b/src/tables.jl @@ -1,6 +1,8 @@ # Tables and TableTraits interface -## Chains +#### +#### Chains +#### function _check_columnnames(chn::Chains) for name in names(chn) @@ -11,8 +13,12 @@ function _check_columnnames(chn::Chains) end end +#### Tables interface + Tables.istable(::Type{<:Chains}) = true +# AbstractColumns interface + Tables.columnaccess(::Type{<:Chains}) = true function Tables.columns(chn::Chains) @@ -26,11 +32,11 @@ function Tables.getcolumn(chn::Chains, i::Int) return Tables.getcolumn(chn, Tables.columnnames(chn)[i]) end function Tables.getcolumn(chn::Chains, nm::Symbol) - if nm == :iteration + if nm === :iteration iterations = range(chn) nchains = size(chn, 3) return repeat(iterations, nchains) - elseif nm == :chain + elseif nm === :chain chainids = chains(chn) niter = size(chn, 1) return repeat(chainids; inner = niter) @@ -39,18 +45,13 @@ function Tables.getcolumn(chn::Chains, nm::Symbol) end end -Tables.rowaccess(::Type{<:Chains}) = true +# row access -function Tables.rows(chn::Chains) - _check_columnnames(chn) - return chn -end +Tables.rowaccess(::Type{<:Chains}) = true -Tables.rowtable(chn::Chains) = Tables.rowtable(Tables.columntable(chn)) +Tables.rows(chn::Chains) = Tables.rows(Tables.columntable(chn)) -function Tables.namedtupleiterator(chn::Chains) - return Tables.namedtupleiterator(Tables.columntable(chn)) -end +# optional Tables overloads function Tables.schema(chn::Chains) _check_columnnames(chn) @@ -60,6 +61,8 @@ function Tables.schema(chn::Chains) return Tables.Schema(nms, types) end +#### TableTraits interface + IteratorInterfaceExtensions.isiterable(::Chains) = true function IteratorInterfaceExtensions.getiterator(chn::Chains) return Tables.datavaluerows(Tables.columntable(chn)) @@ -67,10 +70,16 @@ end TableTraits.isiterabletable(::Chains) = true -## ChainDataFrame +#### +#### ChainDataFrame +#### + +#### Tables interface Tables.istable(::Type{<:ChainDataFrame}) = true +# AbstractColumns interface + Tables.columnaccess(::Type{<:ChainDataFrame}) = true Tables.columns(cdf::ChainDataFrame) = cdf @@ -80,21 +89,19 @@ Tables.columnnames(::ChainDataFrame{<:NamedTuple{names}}) where {names} = names Tables.getcolumn(cdf::ChainDataFrame, i::Int) = cdf.nt[i] Tables.getcolumn(cdf::ChainDataFrame, nm::Symbol) = cdf.nt[nm] -Tables.rowaccess(::Type{<:ChainDataFrame}) = true - -Tables.rows(cdf::ChainDataFrame) = cdf +# row access -Tables.rowtable(cdf::ChainDataFrame) = Tables.rowtable(Tables.columntable(cdf)) +Tables.rowaccess(::Type{<:ChainDataFrame}) = true -function Tables.namedtupleiterator(cdf::ChainDataFrame) - return Tables.namedtupleiterator(Tables.columntable(cdf)) -end +Tables.rows(cdf::ChainDataFrame) = Tables.rows(Tables.columntable(cdf)) function Tables.schema(::ChainDataFrame{NamedTuple{names,T}}) where {names,T} types = ntuple(i -> eltype(fieldtype(T, i)), fieldcount(T)) return Tables.Schema(names, types) end +#### TableTraits interface + IteratorInterfaceExtensions.isiterable(::ChainDataFrame) = true function IteratorInterfaceExtensions.getiterator(cdf::ChainDataFrame) return Tables.datavaluerows(Tables.columntable(cdf)) diff --git a/test/tables_tests.jl b/test/tables_tests.jl index 92f0c4b6..8383372d 100644 --- a/test/tables_tests.jl +++ b/test/tables_tests.jl @@ -13,63 +13,95 @@ using DataFrames @testset "Tables interface" begin @test Tables.istable(typeof(chn)) - @test Tables.columnaccess(typeof(chn)) - @test Tables.columns(chn) === chn - @test Tables.columnnames(chn) == - (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) - @test Tables.getcolumn(chn, :iteration) == [1:1000; 1:1000; 1:1000; 1:1000] - @test Tables.getcolumn(chn, :chain) == - [fill(1, 1000); fill(2, 1000); fill(3, 1000); fill(4, 1000)] - @test Tables.getcolumn(chn, :a) == [ - vec(chn[:, :a, 1]) - vec(chn[:, :a, 2]) - vec(chn[:, :a, 3]) - vec(chn[:, :a, 4]) - ] - @test_throws Exception Tables.getcolumn(chn, :j) - @test Tables.getcolumn(chn, 1) == Tables.getcolumn(chn, :iteration) - @test Tables.getcolumn(chn, 2) == Tables.getcolumn(chn, :chain) - @test Tables.getcolumn(chn, 3) == Tables.getcolumn(chn, :a) - @test_throws Exception Tables.getcolumn(chn, :i) - @test_throws Exception Tables.getcolumn(chn, 11) - @test Tables.rowaccess(typeof(chn)) - @test Tables.rows(chn) === chn - @test length(Tables.rowtable(chn)) == 4000 - nt = Tables.rowtable(chn)[1] - @test nt == - (; (k => Tables.getcolumn(chn, k)[1] for k in Tables.columnnames(chn))...) - @test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 1))[1] - nt = Tables.rowtable(chn)[2] - @test nt == - (; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...) - @test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 2))[2] - @test Tables.schema(chn) isa Tables.Schema - @test Tables.schema(chn).names === - (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) - @test Tables.schema(chn).types === ( - Int, - Int, - Float64, - Float64, - Float64, - Float64, - Float64, - Float64, - Float64, - Float64, - ) - @test Tables.matrix(chn[:, :, 1])[:, 3:end] ≈ chn[:, :, 1].value - @test Tables.matrix(chn[:, :, 2])[:, 3:end] ≈ chn[:, :, 2].value - - val = rand(1000, 2, 4) - chn2 = Chains(val, ["iteration", "a"]) - @test_throws Exception Tables.columns(chn2) - @test_throws Exception Tables.rows(chn2) - @test_throws Exception Tables.schema(chn2) - chn3 = Chains(val, ["chain", "a"]) - @test_throws Exception Tables.columns(chn3) - @test_throws Exception Tables.rows(chn3) - @test_throws Exception Tables.schema(chn3) + + @testset "column access" begin + @test Tables.columnaccess(typeof(chn)) + @test Tables.columns(chn) === chn + @test Tables.columnnames(chn) == + (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) + @test Tables.getcolumn(chn, :iteration) == [1:1000; 1:1000; 1:1000; 1:1000] + @test Tables.getcolumn(chn, :chain) == + [fill(1, 1000); fill(2, 1000); fill(3, 1000); fill(4, 1000)] + @test Tables.getcolumn(chn, :a) == [ + vec(chn[:, :a, 1]) + vec(chn[:, :a, 2]) + vec(chn[:, :a, 3]) + vec(chn[:, :a, 4]) + ] + @test_throws Exception Tables.getcolumn(chn, :j) + @test Tables.getcolumn(chn, 1) == Tables.getcolumn(chn, :iteration) + @test Tables.getcolumn(chn, 2) == Tables.getcolumn(chn, :chain) + @test Tables.getcolumn(chn, 3) == Tables.getcolumn(chn, :a) + @test_throws Exception Tables.getcolumn(chn, :i) + @test_throws Exception Tables.getcolumn(chn, 11) + end + + @testset "row access" begin + @test Tables.rowaccess(typeof(chn)) + @test Tables.rows(chn) isa Tables.RowIterator + @test eltype(Tables.rows(chn)) <: Tables.AbstractRow + rows = collect(Tables.rows(chn)) + @test eltype(rows) <: Tables.AbstractRow + @test size(rows) === (4000,) + for chainid in 1:4, iterid in 1:1000 + row = rows[(chainid - 1) * 1000 + iterid] + @test Tables.columnnames(row) == + (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) + @test Tables.getcolumn(row, 1) == iterid + @test Tables.getcolumn(row, 2) == chainid + @test Tables.getcolumn(row, 3) == chn[iterid, :a, chainid] + @test Tables.getcolumn(row, 10) == chn[iterid, :h, chainid] + @test Tables.getcolumn(row, :iteration) == iterid + @test Tables.getcolumn(row, :chain) == chainid + @test Tables.getcolumn(row, :a) == chn[iterid, :a, chainid] + @test Tables.getcolumn(row, :h) == chn[iterid, :h, chainid] + end + end + + @testset "integration tests" begin + @test length(Tables.rowtable(chn)) == 4000 + nt = Tables.rowtable(chn)[1] + @test nt == + (; (k => Tables.getcolumn(chn, k)[1] for k in Tables.columnnames(chn))...) + @test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 1))[1] + nt = Tables.rowtable(chn)[2] + @test nt == + (; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...) + @test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 2))[2] + @test Tables.matrix(chn[:, :, 1])[:, 3:end] ≈ chn[:, :, 1].value + @test Tables.matrix(chn[:, :, 2])[:, 3:end] ≈ chn[:, :, 2].value + @test Tables.matrix(Tables.rowtable(chn)) == Tables.matrix(Tables.columntable(chn)) + end + + @testset "schema" begin + @test Tables.schema(chn) isa Tables.Schema + @test Tables.schema(chn).names === + (:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h) + @test Tables.schema(chn).types === ( + Int, + Int, + Float64, + Float64, + Float64, + Float64, + Float64, + Float64, + Float64, + Float64, + ) + end + + @testset "exceptions raised if reserved colname used" begin + val2 = rand(1000, 2, 4) + chn2 = Chains(val2, ["iteration", "a"]) + @test_throws Exception Tables.columns(chn2) + @test_throws Exception Tables.rows(chn2) + @test_throws Exception Tables.schema(chn2) + chn3 = Chains(val2, ["chain", "a"]) + @test_throws Exception Tables.columns(chn3) + @test_throws Exception Tables.rows(chn3) + @test_throws Exception Tables.schema(chn3) + end end @testset "TableTraits interface" begin @@ -82,10 +114,10 @@ using DataFrames @test nt == (; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...) - val = rand(1000, 2, 4) - chn2 = Chains(val, ["iteration", "a"]) + val2 = rand(1000, 2, 4) + chn2 = Chains(val2, ["iteration", "a"]) @test_throws Exception IteratorInterfaceExtensions.getiterator(chn2) - chn3 = Chains(val, ["chain", "a"]) + chn3 = Chains(val2, ["chain", "a"]) @test_throws Exception IteratorInterfaceExtensions.getiterator(chn3) end @@ -106,29 +138,54 @@ using DataFrames @testset "Tables interface" begin @test Tables.istable(typeof(cdf)) - @test Tables.columnaccess(typeof(cdf)) - @test Tables.columns(cdf) === cdf - @test Tables.columnnames(cdf) == keys(cdf.nt) - for (k, v) in pairs(cdf.nt) - @test Tables.getcolumn(cdf, k) == v + + @testset "column access" begin + @test Tables.columnaccess(typeof(cdf)) + @test Tables.columns(cdf) === cdf + @test Tables.columnnames(cdf) == keys(cdf.nt) + for (k, v) in pairs(cdf.nt) + @test Tables.getcolumn(cdf, k) == v + end + @test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1]) + @test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2]) + @test_throws Exception Tables.getcolumn(cdf, :blah) + @test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1) + end + + @testset "row access" begin + @test Tables.rowaccess(typeof(cdf)) + @test Tables.rows(cdf) isa Tables.RowIterator + @test eltype(Tables.rows(cdf)) <: Tables.AbstractRow + rows = collect(Tables.rows(cdf)) + @test eltype(rows) <: Tables.AbstractRow + @test size(rows) === (2,) + @testset for i in 1:2 + row = rows[i] + @test Tables.columnnames(row) == keys(cdf.nt) + for j in length(cdf.nt) + @test Tables.getcolumn(row, j) == cdf.nt[j][i] + @test Tables.getcolumn(row, keys(cdf.nt)[j]) == cdf.nt[j][i] + end + end + end + + @testset "integration tests" begin + @test length(Tables.rowtable(cdf)) == length(cdf.nt[1]) + @test Tables.columntable(cdf) == cdf.nt + nt = Tables.rowtable(cdf)[1] + @test nt == (; (k => v[1] for (k, v) in pairs(cdf.nt))...) + @test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1] + nt = Tables.rowtable(cdf)[2] + @test nt == (; (k => v[2] for (k, v) in pairs(cdf.nt))...) + @test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2] + @test Tables.matrix(Tables.rowtable(cdf)) == Tables.matrix(Tables.columntable(cdf)) + end + + @testset "schema" begin + @test Tables.schema(cdf) isa Tables.Schema + @test Tables.schema(cdf).names === keys(cdf.nt) + @test Tables.schema(cdf).types === eltype.(values(cdf.nt)) end - @test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1]) - @test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2]) - @test_throws Exception Tables.getcolumn(cdf, :blah) - @test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1) - @test Tables.rowaccess(typeof(cdf)) - @test Tables.rows(cdf) === cdf - @test length(Tables.rowtable(cdf)) == length(cdf.nt[1]) - @test Tables.columntable(cdf) == cdf.nt - nt = Tables.rowtable(cdf)[1] - @test nt == (; (k => v[1] for (k, v) in pairs(cdf.nt))...) - @test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1] - nt = Tables.rowtable(cdf)[2] - @test nt == (; (k => v[2] for (k, v) in pairs(cdf.nt))...) - @test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2] - @test Tables.schema(cdf) isa Tables.Schema - @test Tables.schema(cdf).names === keys(cdf.nt) - @test Tables.schema(cdf).types === eltype.(values(cdf.nt)) end @testset "TableTraits interface" begin