Skip to content

Commit

Permalink
Add HyperGroup transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Apr 28, 2024
1 parent 84239fb commit 21b110c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/Tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ Base.replace(t::Tensor, old_new::Pair{Symbol,Symbol}...) = Tensor(parent(t), rep

Base.parent(t::Tensor) = t.data
parenttype(::Type{Tensor{T,N,A}}) where {T,N,A} = A
parenttype(::T) where {T<:Tensor} = parenttype(T)

dim(::Tensor, i::Number) = i
dim(t::Tensor, i::Symbol) = first(findall(==(i), inds(t)))
Expand Down
31 changes: 31 additions & 0 deletions src/Transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ end
Convert hyperindices to COPY-tensors, represented by `DeltaArray`s.
This transformation is always used by default when visualizing a `TensorNetwork` with `plot`.
See also: [`HyperGroup`](@ref).
"""
struct HyperFlatten <: Transformation end

Expand Down Expand Up @@ -72,6 +74,35 @@ function transform!(tn::TensorNetwork, ::HyperFlatten)
end
end

"""
HyperGroup <: Transformation
Convert COPY-tensors, represented by `DeltaArray`s, to hyperindices.
See also: [`HyperFlatten`](@ref).
"""
struct HyperGroup <: Transformation end

function transform!(tn::TensorNetwork, ::HyperGroup)
targets = Iterators.filter(x -> parenttype(x) isa DeltaArray, tensors(tn))
for tensor in targets
# remove COPY tensor
delete!(tn, target)

# insert hyperindex
hyperindex = uuid4()

# insert weights vector
if !all(isone, delta(parent(tensor)))
push!(tn, Tensor(delta(parent(tensor)), [hyperindex]))
end

for flatindex in inds(tensor)
replace!(tn, flatindex => hyperindex)
end
end
end

"""
DiagonalReduction <: Transformation
Expand Down

0 comments on commit 21b110c

Please sign in to comment.