From 21b110c0ec7ed965e6f06868b76843cb5d056832 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 28 Apr 2024 14:38:39 +0200 Subject: [PATCH] Add `HyperGroup` transformation --- src/Tensor.jl | 1 + src/Transformations.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/Tensor.jl b/src/Tensor.jl index 03e44ceee..31c4df7f5 100644 --- a/src/Tensor.jl +++ b/src/Tensor.jl @@ -89,6 +89,7 @@ Base.replace(t::Tensor, old_new::Pair{Symbol,Symbol}...) = Tensor(parent(t), rep Base.parent(t::Tensor) = t.data parenttype(::Type{Tensor{T,N,A}}) where {T,N,A} = A +parenttype(::T) where {T<:Tensor} = parenttype(T) dim(::Tensor, i::Number) = i dim(t::Tensor, i::Symbol) = first(findall(==(i), inds(t))) diff --git a/src/Transformations.jl b/src/Transformations.jl index 6470be14e..03413de13 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -41,6 +41,8 @@ end Convert hyperindices to COPY-tensors, represented by `DeltaArray`s. This transformation is always used by default when visualizing a `TensorNetwork` with `plot`. + +See also: [`HyperGroup`](@ref). """ struct HyperFlatten <: Transformation end @@ -72,6 +74,35 @@ function transform!(tn::TensorNetwork, ::HyperFlatten) end end +""" + HyperGroup <: Transformation + +Convert COPY-tensors, represented by `DeltaArray`s, to hyperindices. + +See also: [`HyperFlatten`](@ref). +""" +struct HyperGroup <: Transformation end + +function transform!(tn::TensorNetwork, ::HyperGroup) + targets = Iterators.filter(x -> parenttype(x) isa DeltaArray, tensors(tn)) + for tensor in targets + # remove COPY tensor + delete!(tn, target) + + # insert hyperindex + hyperindex = uuid4() + + # insert weights vector + if !all(isone, delta(parent(tensor))) + push!(tn, Tensor(delta(parent(tensor)), [hyperindex])) + end + + for flatindex in inds(tensor) + replace!(tn, flatindex => hyperindex) + end + end +end + """ DiagonalReduction <: Transformation