Skip to content

Commit

Permalink
Fix replace! function when replacing a Tensor with itself in `pai…
Browse files Browse the repository at this point in the history
…r` argument (#227)

* Fix replace! for same Tensors on pair argument

* Fix tests and add specific test when there is the same tensor in pair argument

* Fragment large replace! testset with smaller ones

* Change testset titles

* Change equality check to egality
  • Loading branch information
jofrevalles authored Nov 4, 2024
1 parent 23b84cc commit db86848
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 39 deletions.
3 changes: 3 additions & 0 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ end
function Base.replace!(tn::AbstractTensorNetwork, pair::Pair{<:Tensor,<:Tensor})
tn = TensorNetwork(tn)
old_tensor, new_tensor = pair

old_tensor === new_tensor && return tn

issetequal(inds(new_tensor), inds(old_tensor)) || throw(ArgumentError("replacing tensor indices don't match"))

push!(tn, new_tensor)
Expand Down
93 changes: 54 additions & 39 deletions test/TensorNetwork_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,60 +459,75 @@
end

@testset "replace tensors" begin
t_ij = Tensor(zeros(2, 2), (:i, :j))
t_ik = Tensor(zeros(2, 2), (:i, :k))
t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m))
t_lm = Tensor(zeros(2, 2), (:l, :m))
tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm])

old_tensor = t_lm
@testset "Basic replacement" begin
t_ij = Tensor(zeros(2, 2), (:i, :j))
t_ik = Tensor(zeros(2, 2), (:i, :k))
t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m))
t_lm = Tensor(zeros(2, 2), (:l, :m))
tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm])

old_tensor = t_lm

@test_throws ArgumentError begin
new_tensor = Tensor(rand(2, 2), (:a, :b))
replace!(tn, old_tensor => new_tensor)
end

@test_throws ArgumentError begin
new_tensor = Tensor(rand(2, 2), (:a, :b))
new_tensor = Tensor(rand(2, 2), (:l, :m))
replace!(tn, old_tensor => new_tensor)

@test new_tensor === only(filter(t -> issetequal(inds(t), [:l, :m]), tensors(tn)))

# Check if connections are maintained
for ind in inds(new_tensor)
tensors_with_ind = tn.indexmap[ind]
@test new_tensor tensors_with_ind
@test !(old_tensor tensors_with_ind)
end
end

new_tensor = Tensor(rand(2, 2), (:l, :m))
replace!(tn, old_tensor => new_tensor)
@testset "TensorNetwork with tensors of equal indices" begin
A = Tensor(rand(2, 2), (:u, :w))
B = Tensor(rand(2, 2), (:u, :w))
tn = TensorNetwork([A, B])

@test new_tensor === only(filter(t -> issetequal(inds(t), [:l, :m]), tensors(tn)))
new_tensor = Tensor(rand(2, 2), (:u, :w))

# Check if connections are maintained
# for label in inds(new_tensor)
# index = tn.inds[label]
# @test new_tensor in index.links
# @test !(old_tensor in index.links)
# end
replace!(tn, B => new_tensor)
@test A tensors(tn)
@test new_tensor tensors(tn)

# New tensor network with two tensors with the same inds
# A = Tensor(rand(2, 2), (:u, :w))
# B = Tensor(rand(2, 2), (:u, :w))
# tn = TensorNetwork([A, B])
tn = TensorNetwork([A, B])
replace!(tn, A => new_tensor)

# new_tensor = Tensor(rand(2, 2), (:u, :w))
@test issetequal(tensors(tn), [new_tensor, B])
end

# replace!(tn, B => new_tensor)
# @test A === tensors(tn)[1]
# @test new_tensor === tensors(tn)[2]
@testset "Sequence of replacements" begin
A = Tensor(zeros(2, 2), (:i, :j))
B = Tensor(zeros(2, 2), (:j, :k))
C = Tensor(zeros(2, 2), (:k, :l))
tn = TensorNetwork([A, B, C])

# tn = TensorNetwork([A, B])
# replace!(tn, A => new_tensor)
@test_throws ArgumentError replace!(tn, A => B, B => C, C => A)

# @test issetequal(tensors(tn), [new_tensor, B])
new_tensor = Tensor(rand(2, 2), (:i, :j))
new_tensor2 = Tensor(ones(2, 2), (:i, :j))

# # Test chain of replacements
# A = Tensor(zeros(2, 2), (:i, :j))
# B = Tensor(zeros(2, 2), (:j, :k))
# C = Tensor(zeros(2, 2), (:k, :l))
# tn = TensorNetwork([A, B, C])
replace!(tn, A => new_tensor, new_tensor => new_tensor2)
@test issetequal(tensors(tn), [new_tensor2, B, C])
end

# @test_throws ArgumentError replace!(tn, A => B, B => C, C => A)
@testset "Replace with itself" begin
A = Tensor(rand(2, 2), (:i, :j))
B = Tensor(rand(2, 2), (:j, :k))
C = Tensor(rand(2, 2), (:k, :l))
tn = TensorNetwork([A, B, C])

# new_tensor = Tensor(rand(2, 2), (:i, :j))
# new_tensor2 = Tensor(ones(2, 2), (:i, :j))
replace!(tn, A => A)

# replace!(tn, A => new_tensor, new_tensor => new_tensor2)
# @test issetequal(tensors(tn), [new_tensor2, B, C])
@test issetequal(tensors(tn), [A, B, C])
end
end

@testset "replace tensors by tensor network" begin
Expand Down

0 comments on commit db86848

Please sign in to comment.