Skip to content

Commit

Permalink
Complete Array API for TArray (#82)
Browse files Browse the repository at this point in the history
* Add many methods for `TArray`

Most of the newly added methods are from Tracker.jl's TrackedArray.

* remove some unnecessary methods

* unit test cases and benchmarks
  • Loading branch information
KDr2 authored Jan 5, 2021
1 parent 4e8dda9 commit fc8afbf
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 101 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ version = "0.5.0"

[deps]
Libtask_jll = "3ae2931a-708c-5973-9c38-ccf7496fb450"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Libtask_jll = "0.4"
julia = "1.3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"

[targets]
test = ["Test"]
test = ["Test", "BenchmarkTools"]
21 changes: 21 additions & 0 deletions deps/methods_of_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using LinearAlgebra
using Statistics

using InteractiveUtils

const MOD_METHODS = Dict{Module, Vector{Symbol}}()

methods = methodswith(AbstractArray)

for method in methods
mod = method.module
names = get!(MOD_METHODS, mod, Vector{Symbol}())
push!(names, method.name)
end

for (k, v) in MOD_METHODS
print(k)
print(":\n\t")
show(v)
print("\n\n")
end
284 changes: 186 additions & 98 deletions src/tarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ end

TArray{T}(d::Integer...) where T = TArray(T, d)
TArray{T}(::UndefInitializer, d::Integer...) where T = TArray(T, d)
TArray{T}(::UndefInitializer, dim::NTuple{N,Int}) where {T,N} = TArray(T, dim)
TArray{T,N}(d::Vararg{<:Integer,N}) where {T,N} = TArray(T, d)
TArray{T,N}(::UndefInitializer, d::Vararg{<:Integer,N}) where {T,N} = TArray{T,N}(d)
TArray{T,N}(dim::NTuple{N,Int}) where {T,N} = TArray(T, dim)
Expand All @@ -43,106 +44,12 @@ function TArray(T::Type, dim)
res
end

#
# Indexing Interface Implementation
#

function Base.getindex(S::TArray{T, N}, I::Vararg{Int,N}) where {T, N}
t, d = task_local_storage(S.ref)
return d[I...]
end

function Base.setindex!(S::TArray{T, N}, x, I::Vararg{Int,N}) where {T, N}
n, d = task_local_storage(S.ref)
cn = n_copies()
newd = d
if cn > n
# println("[setindex!]: $(S.ref) copying data")
newd = deepcopy(d)
task_local_storage(S.ref, (cn, newd))
end
newd[I...] = x
end

function Base.push!(S::TArray{T}, x) where T
n, d = task_local_storage(S.ref)
cn = n_copies()
newd = d
if cn > n
newd = deepcopy(d)
task_local_storage(S.ref, (cn, newd))
end
push!(newd, x)
end
TArray(x::AbstractArray) = convert(TArray, x)

function Base.pop!(S::TArray)
n, d = task_local_storage(S.ref)
cn = n_copies()
newd = d
if cn > n
newd = deepcopy(d)
task_local_storage(S.ref, (cn, newd))
end
pop!(d)
end

function Base.convert(::Type{TArray}, x::Array)
return convert(TArray{eltype(x),ndims(x)}, x)
end
function Base.convert(::Type{TArray{T,N}}, x::Array{T,N}) where {T,N}
res = TArray{T,N}()
n = n_copies()
task_local_storage(res.ref, (n,x))
return res
end

function Base.convert(::Type{Array}, S::TArray)
return convert(Array{eltype(S), ndims(S)}, S)
end
function Base.convert(::Type{Array{T,N}}, S::TArray{T,N}) where {T,N}
n,d = task_local_storage(S.ref)
c = convert(Array{T, N}, deepcopy(d))
return c
end

function Base.display(S::TArray)
arr = S.orig_task.storage[S.ref][2]
@warn "display(::TArray) prints the originating task's storage, " *
"not the current task's storage. " *
"Please use show(::TArray) to display the current task's version of a TArray."
display(arr)
end

Base.show(io::IO, S::TArray) = Base.show(io::IO, task_local_storage(S.ref)[2])

# Base.get(t::Task, S) = S
# Base.get(t::Task, S::TArray) = (t.storage[S.ref][2])
Base.get(S::TArray) = (current_task().storage[S.ref][2])

##
# Iterator Interface
IteratorSize(::Type{TArray{T, N}}) where {T, N} = HasShape{N}()
IteratorEltype(::Type{TArray}) = HasEltype()

# Implements iterate, eltype, length, and size functions,
# as well as firstindex, lastindex, ndims, and axes
for F in (:iterate, :eltype, :length, :size,
:firstindex, :lastindex, :ndims, :axes)
@eval Base.$F(a::TArray, args...) = $F(get(a), args...)
end

#
# Similarity implementation
#

Base.similar(S::TArray) = tzeros(eltype(S), size(S))
Base.similar(S::TArray, ::Type{T}) where {T} = tzeros(T, size(S))
Base.similar(S::TArray, dims::Dims) = tzeros(eltype(S), dims)

##########
# tzeros #
##########
localize(x) = x
localize(x::AbstractArray) = TArray(x)

# Constructors
"""
tzeros(dims, ...)
Expand Down Expand Up @@ -195,3 +102,184 @@ function tfill(val::Real, dim)
task_local_storage(res.ref, (n,d))
return res
end

#
# Conversion between TArray and Array
#
_get(x) = x
function _get(x::TArray)
n, d = task_local_storage(x.ref)
return d
end

function Base.convert(::Type{Array}, x::TArray)
return convert(Array{eltype(x), ndims(x)}, x)
end
function Base.convert(::Type{Array{T,N}}, x::TArray{T,N}) where {T,N}
c = convert(Array{T, N}, deepcopy(_get(x)))
return c
end

function Base.convert(::Type{TArray}, x::AbstractArray)
return convert(TArray{eltype(x),ndims(x)}, x)
end
function Base.convert(::Type{TArray{T,N}}, x::AbstractArray{T,N}) where {T,N}
res = TArray{T,N}()
n = n_copies()
task_local_storage(res.ref, (n,x))
return res
end

#
# Representation
#
function Base.show(io::IO, ::MIME"text/plain", x::TArray)
arr = x.orig_task.storage[x.ref][2]
@warn "Here shows the originating task's storage, " *
"not the current task's storage. " *
"Please explicitly call show(::TArray) to display the current task's version of a TArray."
show(io, MIME("text/plain"), arr)
end

Base.show(io::IO, x::TArray) = Base.show(io::IO, task_local_storage(x.ref)[2])

function Base.summary(io::IO, x::TArray)
print(io, "Task Local Array: ")
summary(io, _get(x))
end

#
# Forward many methods to the underlying array
#
for F in (:size,
:iterate,
:firstindex, :lastindex, :axes)
@eval Base.$F(a::TArray, args...) = $F(_get(a), args...)
end

#
# Similarity implementation
#

Base.similar(x::TArray, ::Type{T}, dims::Dims) where T = TArray(similar(_get(x), T, dims))

for op in [:(==), :]
@eval Base.$op(x::TArray, y::AbstractArray) = Base.$op(_get(x), y)
@eval Base.$op(x::AbstractArray, y::TArray) = Base.$op(x, _get(y))
@eval Base.$op(x::TArray, y::TArray) = Base.$op(_get(x), _get(y))
end

#
# Array Stdlib
#

# Indexing Interface
function Base.getindex(x::TArray{T, N}, I::Vararg{Int,N}) where {T, N}
t, d = task_local_storage(x.ref)
return d[I...]
end

function Base.setindex!(x::TArray{T, N}, e, I::Vararg{Int,N}) where {T, N}
n, d = task_local_storage(x.ref)
cn = n_copies()
newd = d
if cn > n
# println("[setindex!]: $(x.ref) copying data")
newd = deepcopy(d)
task_local_storage(x.ref, (cn, newd))
end
newd[I...] = e
end

function Base.push!(x::TArray{T}, e) where T
n, d = task_local_storage(x.ref)
cn = n_copies()
newd = d
if cn > n
newd = deepcopy(d)
task_local_storage(x.ref, (cn, newd))
end
push!(newd, e)
end

function Base.pop!(x::TArray)
n, d = task_local_storage(x.ref)
cn = n_copies()
newd = d
if cn > n
newd = deepcopy(d)
task_local_storage(x.ref, (cn, newd))
end
pop!(d)
end

# Other methods from stdlib

Base.view(x::TArray, inds...; kwargs...) =
Base.view(_get(x), inds...; kwargs...) |> localize
Base.:-(x::TArray) = (- _get(x)) |> localize
Base.transpose(x::TArray) = transpose(_get(x)) |> localize
Base.adjoint(x::TArray) = adjoint(_get(x)) |> localize
Base.repeat(x::TArray; kw...) = repeat(_get(x); kw...) |> localize

Base.hcat(xs::Union{TArray{T,1}, TArray{T,2}}...) where T =
hcat(_get.(xs)...) |> localize
Base.vcat(xs::Union{TArray{T,1}, TArray{T,2}}...) where T =
vcat(_get.(xs)...) |> localize
Base.cat(xs::Union{TArray{T,1}, TArray{T,2}}...; dims) where T =
cat(_get.(xs)...; dims = dims) |> localize


Base.reshape(x::TArray, dims::Union{Colon,Int}...) = reshape(_get(x), dims) |> localize
Base.reshape(x::TArray, dims::Tuple{Vararg{Union{Int,Colon}}}) =
reshape(_get(x), Base._reshape_uncolon(_get(x), dims)) |> localize
Base.reshape(x::TArray, dims::Tuple{Vararg{Int}}) = reshape(_get(x), dims) |> localize

Base.permutedims(x::TArray, perm) = permutedims(_get(x), perm) |> localize
Base.PermutedDimsArray(x::TArray, perm) = PermutedDimsArray(_get(x), perm) |> localize
Base.reverse(x::TArray; dims) = reverse(_get(x), dims = dims) |> localize

Base.sum(x::TArray; dims = :) = sum(_get(x), dims = dims) |> localize
Base.sum(f::Union{Function,Type},x::TArray) = sum(f.(_get(x))) |> localize
Base.prod(x::TArray; dims=:) = prod(_get(x); dims=dims) |> localize
Base.prod(f::Union{Function, Type}, x::TArray) = prod(f.(_get(x))) |> localize

Base.findfirst(x::TArray, args...) = findfirst(_get(x), args...) |> localize
Base.maximum(x::TArray; dims = :) = maximum(_get(x), dims = dims) |> localize
Base.minimum(x::TArray; dims = :) = minimum(_get(x), dims = dims) |> localize

Base.:/(x::TArray, y::TArray) = _get(x) / _get(y) |> localize
Base.:/(x::AbstractArray, y::TArray) = x / _get(y) |> localize
Base.:/(x::TArray, y::AbstractArray) = _get(x) / y |> localize
Base.:\(x::TArray, y::TArray) = _get(x) \ _get(y) |> localize
Base.:\(x::AbstractArray, y::TArray) = x \ _get(y) |> localize
Base.:\(x::TArray, y::AbstractArray) = _get(x) \ y |> localize
Base.:*(x::TArray, y::TArray) = _get(x) * _get(y) |> localize
Base.:*(x::AbstractArray, y::TArray) = x * _get(y) |> localize
Base.:*(x::TArray, y::AbstractArray) = _get(x) * y |> localize

# broadcast
Base.BroadcastStyle(::Type{TArray{T, N}}) where {T, N} = Broadcast.ArrayStyle{TArray}()
Broadcast.broadcasted(::Broadcast.ArrayStyle{TArray}, f, args...) = f.(_get.(args)...) |> localize

import LinearAlgebra
import LinearAlgebra: \, /, inv, det, logdet, logabsdet, norm

LinearAlgebra.inv(x::TArray) = inv(_get(x)) |> localize
LinearAlgebra.det(x::TArray) = det(_get(x)) |> localize
LinearAlgebra.logdet(x::TArray) = logdet(_get(x)) |> localize
LinearAlgebra.logabsdet(x::TArray) = logabsdet(_get(x)) |> localize
LinearAlgebra.norm(x::TArray, p::Real = 2) =
LinearAlgebra.norm(_get(x), p) |> localize

import LinearAlgebra: dot
dot(x::TArray, ys::TArray) = dot(_get(x), _get(ys)) |> localize
dot(x::AbstractArray, ys::TArray) = dot(x, _get(ys)) |> localize
dot(x::TArray, ys::AbstractArray) = dot(_get(x), ys) |> localize

using Statistics
Statistics.mean(x::TArray; dims = :) = mean(_get(x), dims = dims) |> localize
Statistics.std(x::TArray; kw...) = std(_get(x), kw...) |> localize

# TODO
# * NNlib
21 changes: 21 additions & 0 deletions test/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using BenchmarkTools
using Libtask

println("= Benchmarks on Arrays =")
A = rand(100, 100)
x, y = abs.(rand(Int, 2) .% 100)
print("indexing: ")
@btime $A[$x, $y] + $A[$x, $y]
print("set indexing: ")
@btime $A[$x, $y] = 1
print("broadcast: ")
@btime $A .+ $A

println("= Benchmarks on TArrays =")
TA = Libtask.localize(deepcopy(A))
print("indexing: ")
@btime $TA[$x, $y] + $TA[$x, $y]
print("set indexing: ")
@btime $TA[$x, $y] = 1
print("broadcast: ")
@btime $TA .+ $TA
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ using Test
include("ctask.jl")
include("tarray.jl")
include("tref.jl")

if get(ENV, "BENCHMARK", nothing) != nothing
include("benchmarks.jl")
end
Loading

0 comments on commit fc8afbf

Please sign in to comment.