From 37122911c24f44318e6d4a0840408adb3364cf2a Mon Sep 17 00:00:00 2001 From: Jarrett Revels Date: Tue, 5 Dec 2023 12:33:24 -0500 Subject: [PATCH] enable field-order-agnostic overloads of `fromarrow` for struct types (#493) Motivated by https://github.com/beacon-biosignals/Legolas.jl/issues/94#issuecomment-1837366852 Still requires: - [x] docs - [x] a test - [x] a bit more due diligence benchmarking-wise. `@benchmark`ing the access in the test case from https://github.com/beacon-biosignals/Legolas.jl/issues/94 didn't reveal any perf difference, which seems like a good sign --------- Co-authored-by: Eric Hanson <5846501+ericphanson@users.noreply.github.com> --- Project.toml | 2 +- src/ArrowTypes/Project.toml | 2 +- src/ArrowTypes/src/ArrowTypes.jl | 8 ++++++- src/arraytypes/struct.jl | 37 +++++++++++++++++++------------- src/table.jl | 3 ++- test/runtests.jl | 29 ++++++++++++++++++++++++- 6 files changed, 61 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 80e6073d..1e3ffe3c 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,7 @@ name = "Arrow" uuid = "69666777-d1a9-59fb-9406-91d4454c9d45" authors = ["quinnj "] -version = "2.6.3" +version = "2.7.0" [deps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" diff --git a/src/ArrowTypes/Project.toml b/src/ArrowTypes/Project.toml index 95f411c0..0166f602 100644 --- a/src/ArrowTypes/Project.toml +++ b/src/ArrowTypes/Project.toml @@ -18,7 +18,7 @@ name = "ArrowTypes" uuid = "31f734f8-188a-4ce0-8406-c8a06bd891cd" authors = ["quinnj "] -version = "2.2.2" +version = "2.3.0" [deps] Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl index bd67c5fd..86183b54 100644 --- a/src/ArrowTypes/src/ArrowTypes.jl +++ b/src/ArrowTypes/src/ArrowTypes.jl @@ -170,11 +170,15 @@ overload is necessary. 
A few `ArrowKind`s have/allow slightly more custom overloads for their `fromarrow` methods: * `ListKind{true}`: for `String` types, they may overload `fromarrow(::Type{T}, ptr::Ptr{UInt8}, len::Int) = ...` to avoid materializing a `String` - * `StructKind`: must overload `fromarrow(::Type{T}, x...)` where individual fields are passed as separate + * `StructKind`: + * May overload `fromarrow(::Type{T}, x...)` where individual fields are passed as separate positional arguments; so if my custom type `Interval` has two fields `first` and `last`, then I'd overload like `ArrowTypes.fromarrow(::Type{Interval}, first, last) = ...`. Note the default implementation is `ArrowTypes.fromarrow(::Type{T}, x...) = T(x...)`, so if your type already accepts all arguments in a constructor no additional `fromarrow` method should be necessary (default struct constructors have this behavior). + * Alternatively, may overload `fromarrowstruct(::Type{T}, ::Val{fnames}, x...)`, where `fnames` is a tuple of the + field names corresponding to the values in `x`. This approach is useful when you need to implement deserialization + in a manner that is agnostic to the field order used by the serializer. When implemented, `fromarrowstruct` takes precedence over `fromarrow` in `StructKind` deserialization. """ function fromarrow end fromarrow(::Type{T}, x::T) where {T} = x @@ -302,6 +306,8 @@ struct StructKind <: ArrowKind end ArrowKind(::Type{<:NamedTuple}) = StructKind() +@inline fromarrowstruct(T::Type, ::Val, x...) = fromarrow(T, x...) + fromarrow( ::Type{NamedTuple{names,types}}, x::NamedTuple{names,types}, diff --git a/src/arraytypes/struct.jl b/src/arraytypes/struct.jl index 4ad97526..510d1e46 100644 --- a/src/arraytypes/struct.jl +++ b/src/arraytypes/struct.jl @@ -19,7 +19,7 @@ An `ArrowVector` where each element is a "struct" of some kind with ordered, named fields, like a `NamedTuple{names, types}` or regular julia `struct`. 
""" -struct Struct{T,S} <: ArrowVector{T} +struct Struct{T,S,fnames} <: ArrowVector{T} validity::ValidityBitmap data::S # Tuple of ArrowVector ℓ::Int @@ -33,23 +33,29 @@ isnamedtuple(T) = false istuple(::Type{<:Tuple}) = true istuple(T) = false -@propagate_inbounds function Base.getindex(s::Struct{T,S}, i::Integer) where {T,S} +if isdefined(ArrowTypes, :fromarrowstruct) + # https://github.com/apache/arrow-julia/pull/493 + @inline function _fromarrowstruct(T::Type, v::Val, x...) + return ArrowTypes.fromarrowstruct(T, v, x...) + end +else + @inline function _fromarrowstruct(T::Type, ::Val, x...) + return ArrowTypes.fromarrow(T, x...) + end +end + +@propagate_inbounds function Base.getindex( + s::Struct{T,S,fnames}, + i::Integer, +) where {T,S,fnames} @boundscheck checkbounds(s, i) NT = Base.nonmissingtype(T) + NT !== T && (s.validity[i] || return missing) + vals = ntuple(j -> s.data[j][i], fieldcount(S)) if isnamedtuple(NT) || istuple(NT) - if NT !== T - return s.validity[i] ? NT(ntuple(j -> s.data[j][i], fieldcount(S))) : missing - else - return NT(ntuple(j -> s.data[j][i], fieldcount(S))) - end + return NT(vals) else - if NT !== T - return s.validity[i] ? - ArrowTypes.fromarrow(NT, (s.data[j][i] for j = 1:fieldcount(S))...) : - missing - else - return ArrowTypes.fromarrow(NT, (s.data[j][i] for j = 1:fieldcount(S))...) - end + return _fromarrowstruct(NT, Val{fnames}(), vals...) end end @@ -100,7 +106,8 @@ function arrowvector(::StructKind, x, i, nl, fi, de, ded, meta; kw...) arrowvector(ToStruct(x, j), i, nl + 1, j, de, ded, nothing; kw...) 
for j = 1:fieldcount(T) ) - return Struct{withmissing(eltype(x), namedtupletype(T, data)),typeof(data)}( + NT = namedtupletype(T, data) + return Struct{withmissing(eltype(x), NT),typeof(data),fieldnames(NT)}( validity, data, len, diff --git a/src/table.jl b/src/table.jl index 882a99b1..ecd8b1d8 100644 --- a/src/table.jl +++ b/src/table.jl @@ -840,7 +840,8 @@ function build(f::Meta.Field, L::Meta.Struct, batch, rb, de, nodeidx, bufferidx, data = Tuple(vecs) meta = buildmetadata(f.custom_metadata) T = juliaeltype(f, meta, convert) - return Struct{T,typeof(data)}(validity, data, len, meta), nodeidx, bufferidx + fnames = ntuple(i -> Symbol(f.children[i].name), length(f.children)) + return Struct{T,typeof(data),fnames}(validity, data, len, meta), nodeidx, bufferidx end function build(f::Meta.Field, L::Meta.Union, batch, rb, de, nodeidx, bufferidx, convert) diff --git a/test/runtests.jl b/test/runtests.jl index 48ca399b..ed288b3d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1014,6 +1014,33 @@ end # @test isequal(table.v, table2.v) # end - + if isdefined(ArrowTypes, :fromarrowstruct) + @testset "# 493" begin + # This test stresses the existence of the mechanism + # implemented in https://github.com/apache/arrow-julia/pull/493, + # but doesn't stress the actual use case that motivates + # that mechanism, simply because it'd be more annoying to + # write that test; see the PR for details. + struct Foo493 + x::Int + y::Int + end + ArrowTypes.arrowname(::Type{Foo493}) = Symbol("JuliaLang.Foo493") + ArrowTypes.JuliaType(::Val{Symbol("JuliaLang.Foo493")}, T) = Foo493 + function ArrowTypes.fromarrowstruct( + ::Type{Foo493}, + ::Val{fnames}, + x..., + ) where {fnames} + nt = NamedTuple{fnames}(x) + return Foo493(nt.x + 1, nt.y + 1) + end + t = (; f=[Foo493(1, 2), Foo493(3, 4)]) + buf = Arrow.tobuffer(t) + tbl = Arrow.Table(buf) + @test tbl.f[1] === Foo493(2, 3) + @test tbl.f[2] === Foo493(4, 5) + end + end end # @testset "misc" end