From 35b684fed307cefda555e21d425c363f702401cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
Date: Mon, 10 Jun 2024 17:05:00 +0200
Subject: [PATCH] Integrate `einsum` on `Dagger.DArray`

---
 Project.toml                           |   2 +
 ext/MuscleDaggerExt/Einsum.jl          | 156 +++++++++++++++++++++++++
 ext/MuscleDaggerExt/MuscleDaggerExt.jl |   5 ++
 3 files changed, 163 insertions(+)
 create mode 100644 ext/MuscleDaggerExt/Einsum.jl
 create mode 100644 ext/MuscleDaggerExt/MuscleDaggerExt.jl

diff --git a/Project.toml b/Project.toml
index 65f655a..b2464ee 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,9 +11,11 @@ OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 
 [weakdeps]
+Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
 Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 
 [extensions]
+MuscleDaggerExt = "Dagger"
 MuscleReactantExt = "Reactant"
 
 [compat]
diff --git a/ext/MuscleDaggerExt/Einsum.jl b/ext/MuscleDaggerExt/Einsum.jl
new file mode 100644
index 0000000..4bc983f
--- /dev/null
+++ b/ext/MuscleDaggerExt/Einsum.jl
@@ -0,0 +1,156 @@
+using Muscle
+using Dagger: Dagger, ArrayOp, Context, ArrayDomain, EagerThunk, DArray
+using LinearAlgebra
+
+"""
+    Einsum{T,N} <: ArrayOp{T,N}
+
+Array operation describing the pairwise `einsum` of two Dagger array operations `a` and `b`.
+
+`ia` and `ib` label the dimensions of `a` and `b`; `ic` labels the dimensions of the
+result. `T` is the promoted element type of the operands and `N` is `length(ic)`.
+
+The constructor throws an `ErrorException` if `ia`, `ib` or `ic` contain repeated labels,
+or if `ic` is not a subset of `ia ∪ ib`.
+"""
+struct Einsum{T,N} <: ArrayOp{T,N}
+    ic::Vector{Symbol}
+    a::ArrayOp
+    ia::Vector{Symbol}
+    b::ArrayOp
+    ib::Vector{Symbol}
+
+    function Einsum(ic, a, ia, b, ib)
+        allunique(ia) || throw(ErrorException("ia must have unique indices"))
+        allunique(ib) || throw(ErrorException("ib must have unique indices"))
+        allunique(ic) || throw(ErrorException("ic must have unique indices"))
+        ic ⊆ ia ∪ ib || throw(ErrorException("ic must be a subset of ia ∪ ib"))
+        return new{Base.promote_eltype(a, b),length(ic)}(ic, a, ia, b, ib)
+    end
+end
+
+"""
+    Base.size(x::Einsum)
+
+Return the size of the result: for every label in `x.ic`, the extent of the matching
+dimension of `x.a` or, when the label is not in `x.ia`, of `x.b`.
+"""
+function Base.size(@nospecialize(x::Einsum))
+    return Tuple(
+        Iterators.map(x.ic) do i
+            if i ∈ x.ia
+                size(x.a, findfirst(==(i), x.ia))
+            elseif i ∈ x.ib
+                size(x.b, findfirst(==(i), x.ib))
+            else
+                throw(ErrorException("index $i not found in a nor b"))
+            end
+        end,
+    )
+end
+
+"""
+    Dagger.Blocks(x::Einsum)
+
+Return the block partitioning of the result: every output dimension takes the block size
+of the matching dimension of `x.a` or, when the label is not in `x.ia`, of `x.b`.
+"""
+function Dagger.Blocks(@nospecialize(x::Einsum))
+    return Dagger.Blocks(map(x.ic) do i
+        j = findfirst(==(i), x.ia)
+        isnothing(j) || return x.a.partitioning.blocksize[j]
+
+        j = findfirst(==(i), x.ib)
+        isnothing(j) || return x.b.partitioning.blocksize[j]
+
+        throw(ErrorException("index :$i not found in a nor b"))
+    end...)
+end
+
+"""
+    Dagger.stage(ctx::Context, op::Einsum)
+
+Build the `DArray` described by `op`. For each output block, one `einsum` task is spawned
+per pair of matching chunks of `op.a` and `op.b`, and the partial results are combined
+with `Dagger.treereduce`.
+
+Throws an `ArgumentError` if a contracted index is missing from `op.ia` or `op.ib`.
+"""
+function Dagger.stage(ctx::Context, op::Einsum{T,N}) where {T,N}
+    dims = size(op)
+    domain = Dagger.ArrayDomain(map(l -> 1:l, dims))
+    partitioning = Dagger.Blocks(op)
+    blocksize = partitioning.blocksize
+
+    # `cld` rather than `÷`: a trailing partial block still needs its own subdomain.
+    nblocks = map(cld, dims, blocksize)
+    subdomains = Array{ArrayDomain{N,NTuple{N,UnitRange{Int}}}}(undef, nblocks)
+    for indices in eachindex(IndexCartesian(), subdomains)
+        subdomains[indices] = ArrayDomain(
+            map(Tuple(indices), blocksize, dims) do i, step, len
+                # clamp the last block to the array extent
+                ((i - 1) * step + 1):min(i * step, len)
+            end,
+        )
+    end
+
+    suminds = setdiff(op.ia ∪ op.ib, op.ic)
+    suminds ⊆ op.ia ∩ op.ib || throw(
+        ArgumentError("contracted indices $suminds must appear in both ia and ib")
+    )
+
+    # Once the open dims are dropped, the chunk grids keep their contracted dims in
+    # operand order. `invperm(sortperm(p))` is the rank of each contracted index among
+    # its positions `p`, i.e. the permutation that reorders those dims to follow
+    # `suminds`, so the grids of `a` and `b` can be zipped elementwise.
+    inner_perm_a = invperm(sortperm(map(i -> findfirst(==(i), op.ia), suminds)))
+    inner_perm_b = invperm(sortperm(map(i -> findfirst(==(i), op.ib), suminds)))
+
+    mask_a = op.ic .∈ (op.ia,)
+    mask_b = op.ic .∈ (op.ib,)
+    outer_perm_a = map(i -> findfirst(==(i), op.ia), op.ic[mask_a])
+    outer_perm_b = map(i -> findfirst(==(i), op.ib), op.ic[mask_b])
+
+    chunks = similar(subdomains, EagerThunk)
+    for indices in eachindex(IndexCartesian(), chunks)
+        # keep only the chunks of `a` that feed this output block
+        outer_indices_a = Tuple(indices)[mask_a]
+        chunks_a = dropdims(
+            reduce(zip(outer_perm_a, outer_indices_a); init=Dagger.chunks(op.a)) do acc, (d, i)
+                selectdim(acc, d, i:i)
+            end;
+            dims=Tuple(outer_perm_a),
+        )
+        chunks_a = permutedims(chunks_a, inner_perm_a)
+
+        # same selection for `b`
+        outer_indices_b = Tuple(indices)[mask_b]
+        chunks_b = dropdims(
+            reduce(zip(outer_perm_b, outer_indices_b); init=Dagger.chunks(op.b)) do acc, (d, i)
+                selectdim(acc, d, i:i)
+            end;
+            dims=Tuple(outer_perm_b),
+        )
+        chunks_b = permutedims(chunks_b, inner_perm_b)
+
+        chunks[indices] = Dagger.treereduce(
+            Dagger.AddComputeOp,
+            map(chunks_a, chunks_b) do chunk_a, chunk_b
+                # TODO add ThunkOptions: alloc_util, occupancy, ...
+                Dagger.@spawn einsum(op.ic, chunk_a, op.ia, chunk_b, op.ib)
+            end,
+        )
+    end
+
+    return DArray(T, domain, subdomains, chunks, partitioning)
+end
+
+Muscle.einsum(ic, a::Dagger.DArray, ia, b::Dagger.DArray, ib) = Dagger._to_darray(Einsum(ic, a, ia, b, ib))
+
+# NOTE(review): `contract` is not imported explicitly here; this relies on `using Muscle`
+# bringing it into scope — confirm. Both two-argument methods below apply when `a` and
+# `b` are both `DArray`s, which looks ambiguous — confirm the intended behavior.
+Muscle.einsum(a, b::Dagger.DArray; kwargs...) = contract(b, a; kwargs...)
+function Muscle.einsum(a::Dagger.DArray, b; kwargs...)
+    throw(ArgumentError("einsum on a Dagger.DArray with a non-DArray is not yet supported"))
+end
diff --git a/ext/MuscleDaggerExt/MuscleDaggerExt.jl b/ext/MuscleDaggerExt/MuscleDaggerExt.jl
new file mode 100644
index 0000000..ce79c55
--- /dev/null
+++ b/ext/MuscleDaggerExt/MuscleDaggerExt.jl
@@ -0,0 +1,5 @@
+module MuscleDaggerExt
+
+include("Einsum.jl")
+
+end
\ No newline at end of file