Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make StridedReinterpretArray's get/setindex pointer based. #44186

Merged
merged 9 commits into from
Nov 8, 2023
Prev Previous commit
Next Next commit
Make ReinterpretArray's indexing pure pointer based
if its root parent isa `Array` and it is dense like.
Also add missing `pointer` for `FasterContiguousSubArray`
N5N3 committed Apr 29, 2022
commit a6d13468d29abdc7f3e1d6efedaedb7a49a8d9e0
28 changes: 28 additions & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
@@ -344,15 +344,22 @@ unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} =
end
end

check_store(a::StridedReinterpretArray) = check_store(parent(a))
check_store(a::FastContiguousSubArray) = check_store(parent(a))
check_store(a::Array) = true
check_store(a::AbstractArray) = false

@propagate_inbounds getindex(a::ReinterpretArray) = a[firstindex(a)]

@propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
check_readable(a)
check_store(a) && return _getindex_ptr(a, inds...)
_getindex_ra(a, inds[1], tail(inds))
end

@propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, i::Int) where {T,N,S}
check_readable(a)
check_store(a) && return _getindex_ptr(a, i)
if isa(IndexStyle(a), IndexLinear)
return _getindex_ra(a, i, ())
end
@@ -373,6 +380,15 @@ end

@inline _memcpy!(dst, src, n) = ccall(:memcpy, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Csize_t), dst, src, n)

@inline function _getindex_ptr(a::ReinterpretArray{T}, inds...) where {T}
@boundscheck checkbounds(a, inds...)
li = _to_linear_index(a, inds...)
GC.@preserve a begin
p = pointer(a) + sizeof(T) * (li - 1)
return unsafe_load(p)
end
end

@propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
@@ -488,11 +504,13 @@ end

@propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
check_writable(a)
check_store(a) && return _setindex_ptr!(a, v, inds...)
_setindex_ra!(a, v, inds[1], tail(inds))
end

@propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, i::Int) where {T,N,S}
check_writable(a)
check_store(a) && return _setindex_ptr!(a, v, i)
if isa(IndexStyle(a), IndexLinear)
return _setindex_ra!(a, v, i, ())
end
@@ -512,6 +530,16 @@ end
return a
end

@inline function _setindex_ptr!(a::ReinterpretArray{T}, v, inds...) where {T}
@boundscheck checkbounds(a, inds...)
li = _to_linear_index(a, inds...)
GC.@preserve a begin
p = pointer(a) + sizeof(T) * (li - 1)
unsafe_store!(p, v)
end
return a
end

@propagate_inbounds function _setindex_ra!(a::NonReshapedReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
2 changes: 1 addition & 1 deletion base/subarray.jl
Original file line number Diff line number Diff line change
@@ -432,10 +432,10 @@ find_extended_inds() = ()
function unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{RangeIndex}}}) where {T,N,P}
return unsafe_convert(Ptr{T}, V.parent) + _memory_offset(V.parent, map(first, V.indices)...)
end
unsafe_convert(::Type{Ptr{T}}, V::FastContiguousSubArray{T}) where {T} = unsafe_convert(Ptr{T}, V.parent) + V.offset1 * sizeof(T)

pointer(V::FastSubArray, i::Int) = pointer(V.parent, V.offset1 + V.stride1*i)
pointer(V::FastContiguousSubArray, i::Int) = pointer(V.parent, V.offset1 + i)

function pointer(V::SubArray{<:Any,<:Any,<:Array,<:Tuple{Vararg{RangeIndex}}}, is::AbstractCartesianIndex{N}) where {N}
index = first_index(V)
strds = strides(V)
10 changes: 10 additions & 0 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
@@ -508,3 +508,13 @@ end
@test setindex!(x, SomeSingleton(:), 3, 5) == x2
@test_throws MethodError x[2,4] = nothing
end

@testset "pointer for StridedArray" begin
a = rand(Float64, 251)
v = view(a, UInt(2):UInt(251));
A = reshape(v, 25, 10);
@test A isa StridedArray && pointer(A) === pointer(a, 2)
Av = view(A, 1:20, 1:2)
@test Av isa StridedArray && pointer(Av) === pointer(a, 2)
@test Av * Av' isa Array
end