Skip to content

Commit

Permalink
clean up code
Browse files Browse the repository at this point in the history
  • Loading branch information
Moelf committed Nov 15, 2024
1 parent a2e5b80 commit c346672
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 124 deletions.
14 changes: 9 additions & 5 deletions src/RNTuple/Writing/TFileWriter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,9 @@ end

# primary case
function add_field_column_record!(field_records, column_records, input_T::Type{<:Real}, NAME; parent_field_id, col_field_id = parent_field_id)
fr = UnROOT.FieldRecord(zero(UInt32), zero(UInt32), parent_field_id, zero(UInt16), zero(UInt16), string(NAME), RNTUPLE_WRITE_TYPE_CPPNAME_DICT[input_T], "", "", 0, -1, -1)
cr = UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[input_T]..., col_field_id, 0x00, 0x00, 0)
fr = UnROOT.FieldRecord(zero(UInt32), zero(UInt32), parent_field_id, zero(UInt16), zero(UInt16), string(NAME), RNT_WRITE_CPP_TYPE_NAME_DICT[input_T], "", "", 0, -1, -1)
rnt_col_type = RNT_COL_TYPE_TABLE[RNT_WRITE_JL_TYPE_DICT[input_T] + 1]
cr = UnROOT.ColumnRecord(rnt_col_type.type, rnt_col_type.nbits, col_field_id, 0x00, 0x00, 0)
push!(field_records, fr)
push!(column_records, cr)
nothing
Expand All @@ -487,9 +488,11 @@ function add_field_column_record!(field_records, column_records, input_T::Type{<
fr = UnROOT.FieldRecord(; field_version=0x00000000, type_version=0x00000000, parent_field_id, struct_role=0x0000, flags=0x0000, repetition=0, source_field_id=-1, root_streamer_checksum=-1, field_name=string(NAME), type_name="std::string", type_alias="", field_desc="", )
push!(field_records, fr)

cr_offset = UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[Index64]..., col_field_id, 0x00, 0x00, 0)
rnt_indexcol_type = RNT_COL_TYPE_TABLE[RNT_WRITE_JL_TYPE_DICT[Index64] + 1]
cr_offset = UnROOT.ColumnRecord(rnt_indexcol_type.type, rnt_col_type.nbits, col_field_id, 0x00, 0x00, 0)

Check warning on line 492 in src/RNTuple/Writing/TFileWriter.jl

View check run for this annotation

Codecov / codecov/patch

src/RNTuple/Writing/TFileWriter.jl#L491-L492

Added lines #L491 - L492 were not covered by tests
push!(column_records, cr_offset)
cr_chars = UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[Char]..., col_field_id, 0x00, 0x00, 0)
rnt_charcol_type = RNT_COL_TYPE_TABLE[RNT_WRITE_JL_TYPE_DICT[Char] + 1]
cr_chars = UnROOT.ColumnRecord(rnt_charcol_type.type, rnt_char_type.nbits, col_field_id, 0x00, 0x00, 0)

Check warning on line 495 in src/RNTuple/Writing/TFileWriter.jl

View check run for this annotation

Codecov / codecov/patch

src/RNTuple/Writing/TFileWriter.jl#L494-L495

Added lines #L494 - L495 were not covered by tests
push!(column_records, cr_chars)
nothing
end
Expand All @@ -499,7 +502,8 @@ function add_field_column_record!(field_records, column_records, input_T::Type{<
implicit_field_id = length(field_records)
fr = UnROOT.FieldRecord(; field_version=0x00000000, type_version=0x00000000, parent_field_id, struct_role=0x0001, flags=0x0000, repetition=0, source_field_id=-1, root_streamer_checksum=-1, field_name=string(NAME), type_name="", type_alias="", field_desc="", )
push!(field_records, fr)
cr_offset = UnROOT.ColumnRecord(RNTUPLE_WRITE_TYPE_IDX_DICT[Index64]..., col_field_id, 0x00, 0x00, 0)
rnt_col_type = RNT_COL_TYPE_TABLE[RNT_WRITE_JL_TYPE_DICT[Index64] + 1]
cr_offset = UnROOT.ColumnRecord(rnt_col_type.type, rnt_col_type.nbits, col_field_id, 0x00, 0x00, 0)
push!(column_records, cr_offset)

# TODO: this feels like a hack, think about it more
Expand Down
94 changes: 17 additions & 77 deletions src/RNTuple/Writing/page_writing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
rnt_col_to_ary(col) -> Vector{Vector}
Normalize each user-facing "column" into a collection of Vector{<:Real} ready to be written to a page.
After calling this on all user-facing "column", we should have as many `ary`s as our `ColumnRecord`s.
After calling this on all user-facing "column", we should have as many `ary`s as our `ColumnRecord`s and
in the same order.
"""
rnt_col_to_ary(col::AbstractVector{<:Real}) = Any[col]

function rnt_col_to_ary(col::AbstractVector{<:AbstractVector})
vov = VectorOfVectors(col)
content = flatview(vov)
Expand All @@ -15,7 +15,6 @@ function rnt_col_to_ary(col::AbstractVector{<:AbstractVector})

Any[rnt_col_to_ary(offset_adjust); rnt_col_to_ary(content)]
end

function rnt_col_to_ary(col::AbstractVector{<:AbstractString})
rnt_col_to_ary(codeunits.(col))
end
Expand All @@ -28,90 +27,31 @@ Turns an AbstractVector into a page of an RNTuple. The element type must be prim
"""
function rnt_ary_to_page(ary::AbstractVector, cr::ColumnRecord) end


function rnt_ary_to_page(ary::AbstractVector{Bool}, cr::ColumnRecord)
chunks = BitVector(ary).chunks
Page_write(reinterpret(UInt8, chunks))
end

function rnt_ary_to_page(ary::AbstractVector{Float64}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
if split
Page_write(split8_encode(reinterpret(UInt8, ary)))
else
Page_write(reinterpret(UInt8, ary))
end
end

function rnt_ary_to_page(ary::AbstractVector{Float32}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
if split
Page_write(split4_encode(reinterpret(UInt8, ary)))
else
Page_write(reinterpret(UInt8, ary))
end
end

function rnt_ary_to_page(ary::AbstractVector{Float16}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
if split
Page_write(split2_encode(reinterpret(UInt8, ary)))
else
Page_write(reinterpret(UInt8, ary))
end
end

function rnt_ary_to_page(ary::AbstractVector{UInt64}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
if split
Page_write(split8_encode(reinterpret(UInt8, ary)))
else
Page_write(reinterpret(UInt8, ary))
end
end

function rnt_ary_to_page(ary::AbstractVector{UInt32}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
if split
Page_write(split4_encode(reinterpret(UInt8, ary)))
else
Page_write(reinterpret(UInt8, ary))
end
function rnt_ary_to_page(ary::AbstractVector{T}, cr::ColumnRecord) where T<:Number
Page_write(page_encode(ary, cr))
end

function rnt_ary_to_page(ary::AbstractVector{UInt16}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
if split
Page_write(split2_encode(reinterpret(UInt8, ary)))
function page_encode(ary::AbstractVector{T}, cr::ColumnRecord) where T
col_type = RNT_COL_TYPE_TABLE[cr.type+1]
nbits = col_type.nbits
src = reinterpret(UInt8, ary)
if col_type.issplit
if nbits == 64
split8_encode(src)
elseif nbits == 32
split4_encode(src)
elseif nbits == 16
split2_encode(src)

Check warning on line 49 in src/RNTuple/Writing/page_writing.jl

View check run for this annotation

Codecov / codecov/patch

src/RNTuple/Writing/page_writing.jl#L44-L49

Added lines #L44 - L49 were not covered by tests
end
else
Page_write(reinterpret(UInt8, ary))
src
end
end

function rnt_ary_to_page(ary::AbstractVector{UInt8}, cr::ColumnRecord)
Page_write(ary)
end

function rnt_ary_to_page(ary::AbstractVector{Int64}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
Page_write(reinterpret(UInt8, ary))
end

function rnt_ary_to_page(ary::AbstractVector{Int32}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
Page_write(reinterpret(UInt8, ary))
end

function rnt_ary_to_page(ary::AbstractVector{Int16}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
Page_write(reinterpret(UInt8, ary))
end

function rnt_ary_to_page(ary::AbstractVector{Int8}, cr::ColumnRecord)
(;split, zigzag, delta) = _detect_encoding(cr.type)
Page_write(reinterpret(UInt8, ary))
end

function split8_encode(src::AbstractVector{UInt8})
@views [src[1:8:end-7]; src[2:8:end-6]; src[3:8:end-5]; src[4:8:end-4]; src[5:8:end-3]; src[6:8:end-2]; src[7:8:end-1]; src[8:8:end]]
end
Expand Down
39 changes: 20 additions & 19 deletions src/RNTuple/constants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
@define_integers 64 SignedIndex64 Index64
Base.promote_rule(::Type{Int64}, ::Type{Index64}) = Int64
Base.promote_rule(::Type{Index64}, ::Type{Int64}) = Int64
Base.promote_rule(::Type{Int64}, ::Type{Index32}) = Int64
Base.promote_rule(::Type{Index32}, ::Type{Int64}) = Int64

Check warning on line 8 in src/RNTuple/constants.jl

View check run for this annotation

Codecov / codecov/patch

src/RNTuple/constants.jl#L7-L8

Added lines #L7 - L8 were not covered by tests

@kwdef struct RNTuple_ColumnType
type::UInt8
Expand All @@ -16,7 +18,7 @@ Base.promote_rule(::Type{Index64}, ::Type{Int64}) = Int64
end

#https://github.com/root-project/root/blob/1de46e89958fd3946d2d6995c810391b781d39ac/tree/ntuple/v7/doc/BinaryFormatSpecification.md?plain=1#L479
const rntuple_col_type_table = (
const RNT_COL_TYPE_TABLE = (
RNTuple_ColumnType(type = 0x00, nbits = 1, name = :Bit , jltype = Bool),
RNTuple_ColumnType(type = 0x01, nbits = 8, name = :Byte , jltype = UInt8),
RNTuple_ColumnType(type = 0x02, nbits = 8, name = :Char , jltype = UInt8),
Expand Down Expand Up @@ -49,26 +51,25 @@ RNTuple_ColumnType(type = 0x1B, nbits = 64, name = :SplitIndex64, jltype = Index
# (0x1D, 1-32, :Real32Quant ), #??
)

# for each Julia type, we pick just one canonical representation for writing
const RNTUPLE_WRITE_TYPE_IDX_DICT = Dict(
Index64 => (0x0F, sizeof(Index64) * 8),
Index32 => (0x0E, sizeof(Index32) * 8),
Char => (0x02, 8),
Bool => (0x00, 1),
Float64 => (0x0D, sizeof(Float64) * 8),
Float32 => (0x0C, sizeof(Float32) * 8),
Float16 => (0x0B, sizeof(Float16) * 8),
UInt64 => (0x0A, sizeof(UInt64) * 8),
UInt32 => (0x08, sizeof(UInt32) * 8),
UInt16 => (0x06, sizeof(UInt16) * 8),
UInt8 => (0x04, sizeof(UInt8) * 8),
Int64 => (0x09, sizeof(Int64) * 8),
Int32 => (0x07, sizeof(Int32) * 8),
Int16 => (0x05, sizeof(Int16) * 8),
Int8 => (0x03, sizeof(Int8) * 8),
const RNT_WRITE_JL_TYPE_DICT = Dict(
Index64 => 0x0F,
Index32 => 0x0E,
Char => 0x02,
Bool => 0x00,
Float64 => 0x0D,
Float32 => 0x0C,
Float16 => 0x0B,
UInt64 => 0x0A,
UInt32 => 0x08,
UInt16 => 0x06,
UInt8 => 0x04,
Int64 => 0x09,
Int32 => 0x07,
Int16 => 0x05,
Int8 => 0x03,
)

const RNTUPLE_WRITE_TYPE_CPPNAME_DICT = Dict(
const RNT_WRITE_CPP_TYPE_NAME_DICT = Dict(
Bool => "bool",
Float16 => "std::float16_t",
Float32 => "float",
Expand Down
2 changes: 1 addition & 1 deletion src/RNTuple/fieldcolumn_reading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ function read_field(io, field::UnionField{S, T}, page_list) where {S, T}
end

function _detect_encoding(typenum)
col_type = rntuple_col_type_table[typenum+1]
col_type = RNT_COL_TYPE_TABLE[typenum+1]
split = col_type.issplit
zigzag = col_type.iszigzag
delta = col_type.isdelta
Expand Down
4 changes: 2 additions & 2 deletions src/RNTuple/fieldcolumn_schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ function _search_col_type(field_id, column_records, col_id::Int...)
index_record = column_records[col_id[1]]
char_record = column_records[col_id[2]]
index_typenum = index_record.type
LeafType = rntuple_col_type_table[index_typenum+0x01].jltype
LeafType = RNT_COL_TYPE_TABLE[index_typenum+0x01].jltype
return StringField(
LeafField{LeafType}(col_id[1],index_record),
LeafField{Char}(col_id[2], char_record)
)
elseif length(col_id) == 1
record = column_records[only(col_id)]
LeafType = rntuple_col_type_table[record.type+0x01].jltype
LeafType = RNT_COL_TYPE_TABLE[record.type+0x01].jltype
return LeafField{LeafType}(only(col_id), record)
else
error("un-handled RNTuple case, report issue to UnROOT.jl")
Expand Down
40 changes: 20 additions & 20 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,30 @@ using UnROOT
nthreads = UnROOT._maxthreadid()
nthreads == 1 && @warn "Running on a single thread. Please re-run the test suite with at least two threads (`julia --threads 2 ...`)"

@testset "UnROOT tests" verbose = true begin
include("Aqua.jl")
include("bootstrapping.jl")
include("compressions.jl")
include("jagged.jl")
include("lazy.jl")
include("histograms.jl")
include("views.jl")
include("multithreading.jl")
include("remote.jl")
include("displays.jl")
include("type_stability.jl")
include("utils.jl")
include("misc.jl")
# @testset "UnROOT tests" verbose = true begin
# include("Aqua.jl")
# include("bootstrapping.jl")
# include("compressions.jl")
# include("jagged.jl")
# include("lazy.jl")
# include("histograms.jl")
# include("views.jl")
# include("multithreading.jl")
# include("remote.jl")
# include("displays.jl")
# include("type_stability.jl")
# include("utils.jl")
# include("misc.jl")

include("type_support.jl")
include("custom_bootstrapping.jl")
include("lorentzvectors.jl")
include("NanoAOD.jl")
# include("type_support.jl")
# include("custom_bootstrapping.jl")
# include("lorentzvectors.jl")
# include("NanoAOD.jl")

include("issues.jl")
# include("issues.jl")

if VERSION >= v"1.9"
include("rntuple.jl")
include("./RNTupleWriting/lowlevel.jl")
end
end
# end

0 comments on commit c346672

Please sign in to comment.