From d0c644095e27f683b39848ce5e2bae9f40a00a05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 28 Apr 2024 21:16:23 +0200 Subject: [PATCH] Test and fix `HyperGroup` transformation --- src/Transformations.jl | 16 ++++++++++++---- test/Transformations_test.jl | 23 +++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/src/Transformations.jl b/src/Transformations.jl index 7056906c3..605fed307 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -84,13 +84,17 @@ See also: [`HyperFlatten`](@ref). struct HyperGroup <: Transformation end function transform!(tn::TensorNetwork, ::HyperGroup) - targets = Iterators.filter(x -> parenttype(x) isa DeltaArray, tensors(tn)) + targets = Iterators.filter(x -> parenttype(x) <: DeltaArray, tensors(tn)) + + open_indices = inds(tn; set=:open) + targets = Iterators.filter(t -> isdisjoint(inds(t), open_indices), targets) + for tensor in targets # remove COPY tensor - delete!(tn, target) + delete!(tn, tensor) # insert hyperindex - hyperindex = uuid4() + hyperindex = Symbol(uuid4()) # insert weights vector if !all(isone, delta(parent(tensor))) @@ -98,9 +102,13 @@ function transform!(tn::TensorNetwork, ::HyperGroup) end for flatindex in inds(tensor) - replace!(tn, flatindex => hyperindex) + tensor = pop!(tn, only(select(tn, :containing, flatindex))) + tensor = replace(tensor, flatindex => hyperindex) + push!(tn, tensor) end end + + return tn end """ diff --git a/test/Transformations_test.jl b/test/Transformations_test.jl index dc5fe0d50..074bb3738 100644 --- a/test/Transformations_test.jl +++ b/test/Transformations_test.jl @@ -33,6 +33,29 @@ # TODO @test issetequal(neighbours()) end + @testset "HyperGroup" begin + using Tenet: HyperGroup + + @testset "open indices" begin + tn = TensorNetwork([Tensor(DeltaArray{3}(ones(2)), [:i, :j, :k])]) + transform!(tn, HyperGroup) + + @test isempty(inds(tn, :hyper)) + end + + @testset "closed indices" begin + tn = TensorNetwork([ + Tensor(rand(2), [:i]), + Tensor(rand(2), [:j]), + Tensor(rand(2), [:k]), + Tensor(DeltaArray{3}(ones(2)), [:i, :j, :k]), + ]) + transform!(tn, HyperGroup) + + @test length(inds(tn, :hyper)) == 1 + end + end + @testset "DiagonalReduction" begin using Tenet: DiagonalReduction, find_diag_axes