Skip to content

Commit

Permalink
Fix contract methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Aug 29, 2023
1 parent 83b2252 commit 224c1d2
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ TenetQuacExt = "Quac"
Bijections = "0.1"
Combinatorics = "1.0"
DeltaArrays = "0.1.1"
EinExprs = "0.5.1"
EinExprs = "0.5.2"
GraphMakie = "0.4,0.5"
Graphs = "1.7"
Makie = "0.18, 0.19"
Expand Down
22 changes: 17 additions & 5 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,15 +456,15 @@ EinExprs.einexpr(tn::TensorNetwork; optimizer = Greedy, outputs = inds(tn, :open
# TODO sequence of indices?
# TODO what if parallel neighbour indices?
"""
contract!(tn::TensorNetwork, index::Symbol)
contract!(tn::TensorNetwork, index)
In-place contraction of tensors connected to `index`.
See also: [`contract`](@ref).
"""
function contract!(tn::TensorNetwork, i::Symbol)
function contract!(tn::TensorNetwork, i)
tensor = reduce(pop!(tn, i)) do acc, tensor
contract(acc, tensor, i)
contract(acc, tensor, dims = i)
end

push!(tn, tensor)
Expand All @@ -480,10 +480,22 @@ The `kwargs` will be passed down to the [`einexpr`](@ref) function.
See also: [`einexpr`](@ref), [`contract!`](@ref).
"""
contract(tn::TensorNetwork; outputs = inds(tn, :open), kwargs...) = contract(einexpr(tn; outputs = outputs, kwargs...))
function contract(tn::TensorNetwork; kwargs...)
path = einexpr(tn; kwargs...)

tn = copy(tn)

for indices in contractorder(path)
contract!(tn, indices)
end

tensors(tn) |> only
end

contract!(t::Tensor, tn::TensorNetwork; kwargs...) = contract!(tn, t; kwargs...)
contract!(tn::TensorNetwork, t::Tensor; kwargs...) = (push!(tn, t); contract(tn; kwargs...))
contract(t::Tensor, tn::TensorNetwork; kwargs...) = contract(tn, t; kwargs...)
contract(tn::TensorNetwork, t::Tensor; kwargs...) = (tn = copy(tn); push!(tn, t); contract(tn; kwargs...))
contract(tn::TensorNetwork, t::Tensor; kwargs...) = contract!(copy(tn), t; kwargs...)

struct TNSampler{A<:Ansatz,NT<:NamedTuple} <: Random.Sampler{TensorNetwork{A}}
parameters::NT
Expand Down
2 changes: 1 addition & 1 deletion test/MatrixProductOperator_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@

@testset "norm" begin
mpo = rand(MatrixProduct{Operator,Open}, n = 8, p = 2, χ = 8)
@test_skip norm(mpo) 1
@test norm(mpo) 1
end

# @testset "Initialization" begin
Expand Down
2 changes: 1 addition & 1 deletion test/MatrixProductState_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,6 @@

@testset "norm" begin
mps = rand(MatrixProduct{State,Open}, n = 8, p = 2, χ = 8)
@test_skip norm(mps) 1
@test norm(mps) 1
end
end
4 changes: 2 additions & 2 deletions test/TensorNetwork_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,12 @@

@testset "contract" begin
tn = rand(TensorNetwork, 5, 3)
@test_skip contract(tn) isa Tensor
@test contract(tn) isa Tensor

A = Tensor(rand(2, 2, 2), (:i, :j, :k))
B = Tensor(rand(2, 2, 2), (:k, :l, :m))
tn = TensorNetwork([A, B])
@test_skip contract(tn) isa Tensor
@test contract(tn) isa Tensor
end

@testset "Base.replace!" begin
Expand Down
14 changes: 7 additions & 7 deletions test/Transformations_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
)

# Test that the resulting contraction returns the same as the original
@test_skip contract(reduced) contract(tn)
@test contract(reduced) contract(tn)
end

@testset "openinds" begin
Expand Down Expand Up @@ -111,7 +111,7 @@
end

# Test that the resulting contraction returns the same as the original
@test_skip contract(reduced) contract(tn)
@test contract(reduced) contract(tn)
end
end

Expand All @@ -136,7 +136,7 @@
@test length(tensors(reduced)) length(tensors(tn))

# Test that the resulting contraction contains the same as the original
@test_skip contract(reduced) contract(tn)
@test contract(reduced) contract(tn)
end

@testset "AntiDiagonalGauging" begin
Expand Down Expand Up @@ -185,7 +185,7 @@
end

# Test that the resulting contraction is the same as the original
@test_skip contract(gauged) contract(tn)
@test contract(gauged) contract(tn)
end

@testset "ColumnReduction" begin
Expand Down Expand Up @@ -214,7 +214,7 @@
@test length(tn.indices) > length(reduced.indices)

# Test that the resulting contraction is the same as the original
@test_skip contract(reduced) contract(contract(A, B; dims = []), C)
@test contract(reduced) contract(contract(A, B; dims = []), C)
end

@testset "index size reduction" begin
Expand All @@ -239,7 +239,7 @@
@test length(tn.indices) == length(reduced.indices)

# Test that the resulting contraction is the same as the original
@test_skip contract(reduced) view(contract(tn), :j => 1:2:3)
@test contract(reduced) view(contract(tn), :j => 1:2:3)
end
end

Expand All @@ -266,6 +266,6 @@
@test smallest_deleted > largest_new

# Test that the resulting contraction is the same as the original
@test_skip contract(reduced) contract(tn)
@test contract(reduced) contract(tn)
end
end

0 comments on commit 224c1d2

Please sign in to comment.