From ceae9fa7849e3cda015d0ab7a3081e82d1050b8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 28 Apr 2024 15:16:32 +0200 Subject: [PATCH] Speedup `RankSimplification` --- src/TensorNetwork.jl | 5 +++ src/Transformations.jl | 82 ++++++++++++++++++++++-------------------- 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 729dade90..8e04b58cf 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -91,6 +91,7 @@ Return the names of the indices in the [`TensorNetwork`](@ref). + `:open` Indices only mentioned in one tensor. + `:inner` Indices mentioned at least twice. + `:hyper` Indices mentioned at least in three tensors. + + `:parallel` Indices parallel to `i` in the graph (`i` included). """ Tenet.inds(tn::TensorNetwork; set::Symbol=:all, kwargs...) = inds(tn, set; kwargs...) @valsplit 2 Tenet.inds(tn::TensorNetwork, set::Symbol, args...) = throw(MethodError(inds, "unknown set=$set")) @@ -111,6 +112,10 @@ function Tenet.inds(tn::TensorNetwork, ::Val{:hyper}) return map(first, Iterators.filter(((_, v),) -> length(v) >= 3, tn.indexmap)) end +function Tenet.inds(tn::TensorNetwork, ::Val{:parallel}, i::Symbol) + return mapreduce(inds, ∩, select(tn, :containing, i)) +end + """ size(tn::TensorNetwork) size(tn::TensorNetwork, index) diff --git a/src/Transformations.jl b/src/Transformations.jl index 03413de13..ef956bb21 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -103,6 +103,49 @@ function transform!(tn::TensorNetwork, ::HyperGroup) end end +""" + RankSimplification <: Transformation + +Preemptively contract tensors whose result doesn't increase in size. +""" +@kwdef struct RankSimplification <: Transformation + minimize::Symbol = :length + + function RankSimplification(minimize::Symbol) + @assert minimize in (:length, :rank) + return new(minimize) + end +end + +function transform!(tn::TensorNetwork, config::RankSimplification) + # select indices that benefit from contraction + targets = filter(inds(tn; set=:inner)) do index + candidate_tensors = select(tn, :containing, index) + + # check that the contraction minimizes the size/rank + result = sum([ + EinExpr(inds(tensor), Dict(index => size(tensor, index) for index in inds(tensor))) for + tensor in candidate_tensors + ]) + + if config.minimize == :rank + return ndims(result) <= minimum(ndims, candidate_tensors) + end + + return length(result) <= minimum(length, candidate_tensors) + end + + # group parallel indices + targets = unique(Iterators.map(x -> inds(tn, :parallel, x), targets)) + + # contract target indices + for target in targets + contract!(tn, target) + end + + return tn +end + """ DiagonalReduction <: Transformation @@ -152,45 +195,6 @@ function transform!(tn::TensorNetwork, config::DiagonalReduction) return tn end -""" - RankSimplification <: Transformation - -Preemptively contract tensors whose result doesn't increase in size. -""" -struct RankSimplification <: Transformation end - -function transform!(tn::TensorNetwork, ::RankSimplification) - @label rank_transformation_start - for tensor in tensors(tn) - # TODO replace this code for `neighbours` method - connected_tensors = mapreduce(label -> select(tn, :any, label), ∪, inds(tensor)) - filter!(!=(tensor), connected_tensors) - - for c_tensor in connected_tensors - # TODO keep output inds? - path = sum([ - EinExpr(inds(tensor), Dict(index => size(tensor, index) for index in inds(tensor))) for - tensor in [tensor, c_tensor] - ]) - - # Check if contraction does not increase the rank - EinExprs.removedsize(path) < 0 && continue - - new_tensor = contract(tensor, c_tensor) - - # Update tensor network - push!(tn, new_tensor) - delete!(tn, tensor) - delete!(tn, c_tensor) - - # Break the loop since we modified the network and need to recheck connections - @goto rank_transformation_start - end - end - - return tn -end - """ AntiDiagonalGauging <: Transformation