Skip to content

Commit

Permalink
Test and fix HyperGroup transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Apr 28, 2024
1 parent 122715b commit d0c6440
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/Transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,31 @@ See also: [`HyperFlatten`](@ref).
struct HyperGroup <: Transformation end

function transform!(tn::TensorNetwork, ::HyperGroup)
targets = Iterators.filter(x -> parenttype(x) isa DeltaArray, tensors(tn))
targets = Iterators.filter(x -> parenttype(x) <: DeltaArray, tensors(tn))

open_indices = inds(tn; set=:open)
targets = Iterators.filter(t -> isdisjoint(inds(t), open_indices), targets)

for tensor in targets
# remove COPY tensor
delete!(tn, target)
delete!(tn, tensor)

# insert hyperindex
hyperindex = uuid4()
hyperindex = Symbol(uuid4())

# insert weights vector
if !all(isone, delta(parent(tensor)))
push!(tn, Tensor(delta(parent(tensor)), [hyperindex]))
end

for flatindex in inds(tensor)
replace!(tn, flatindex => hyperindex)
tensor = pop!(tn, only(select(tn, :containing, flatindex)))
tensor = replace(tensor, flatindex => hyperindex)
push!(tn, tensor)
end
end

return tn
end

"""
Expand Down
23 changes: 23 additions & 0 deletions test/Transformations_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,29 @@
# TODO @test issetequal(neighbours())
end

@testset "HyperGroup" begin
using Tenet: HyperGroup

@testset "open indices" begin
tn = TensorNetwork([Tensor(DeltaArray{3}(ones(2)), [:i, :j, :k])])
transform!(tn, HyperGroup)

@test isempty(inds(tn, :hyper))
end

@testset "closed indices" begin
tn = TensorNetwork([
Tensor(rand(2), [:i]),
Tensor(rand(2), [:j]),
Tensor(rand(2), [:k]),
Tensor(DeltaArray{3}(ones(2)), [:i, :j, :k]),
])
transform!(tn, HyperGroup)

@test length(inds(tn, :hyper)) == 1
end
end

@testset "DiagonalReduction" begin
using Tenet: DiagonalReduction, find_diag_axes

Expand Down

0 comments on commit d0c6440

Please sign in to comment.