Skip to content

Commit

Permalink
Speedup RankSimplification
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Apr 28, 2024
1 parent 21b110c commit ceae9fa
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 39 deletions.
5 changes: 5 additions & 0 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Return the names of the indices in the [`TensorNetwork`](@ref).
+ `:open` Indices only mentioned in one tensor.
+ `:inner` Indices mentioned at least twice.
+ `:hyper` Indices mentioned at least in three tensors.
+ `:parallel` Indices parallel to `i` in the graph (`i` included).
"""
Tenet.inds(tn::TensorNetwork; set::Symbol=:all, kwargs...) = inds(tn, set; kwargs...)
@valsplit 2 Tenet.inds(tn::TensorNetwork, set::Symbol, args...) = throw(MethodError(inds, "unknown set=$set"))
Expand All @@ -111,6 +112,10 @@ function Tenet.inds(tn::TensorNetwork, ::Val{:hyper})
return map(first, Iterators.filter(((_, v),) -> length(v) >= 3, tn.indexmap))
end

function Tenet.inds(tn::TensorNetwork, ::Val{:parallel}, i::Symbol)
return mapreduce(inds, , select(tn, :containing, i))
end

"""
size(tn::TensorNetwork)
size(tn::TensorNetwork, index)
Expand Down
82 changes: 43 additions & 39 deletions src/Transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,49 @@ function transform!(tn::TensorNetwork, ::HyperGroup)
end
end

"""
RankSimplification <: Transformation
Preemptively contract tensors whose result doesn't increase in size.
"""
@kwdef struct RankSimplification <: Transformation
minimize::Symbol = :length

function RankSimplification(minimize::Symbol)
@assert minimize in (:length, :rank)
return new(minimize)
end
end

function transform!(tn::TensorNetwork, config::RankSimplification)
# select indices that benefit from contraction
targets = filter(inds(tn; set=:inner)) do index
candidate_tensors = select(tn, :containing, index)

# check that the contraction minimizes the size/rank
result = sum([
EinExpr(inds(tensor), Dict(index => size(tensor, index) for index in inds(tensor))) for
tensor in candidate_tensors
])

if config.minimize == :rank
return ndims(result) <= minimum(ndims, candidate_tensors)
end

return length(result) <= minimum(length, candidate_tensors)
end

# group parallel indices
targets = unique(Iterators.map(x -> inds(tn, :parallel, x), targets))

# contract target indices
for target in targets
contract!(tn, target)
end

return tn
end

"""
DiagonalReduction <: Transformation
Expand Down Expand Up @@ -152,45 +195,6 @@ function transform!(tn::TensorNetwork, config::DiagonalReduction)
return tn
end

"""
RankSimplification <: Transformation
Preemptively contract tensors whose result doesn't increase in size.
"""
struct RankSimplification <: Transformation end

function transform!(tn::TensorNetwork, ::RankSimplification)
@label rank_transformation_start
for tensor in tensors(tn)
# TODO replace this code for `neighbours` method
connected_tensors = mapreduce(label -> select(tn, :any, label), , inds(tensor))
filter!(!=(tensor), connected_tensors)

for c_tensor in connected_tensors
# TODO keep output inds?
path = sum([
EinExpr(inds(tensor), Dict(index => size(tensor, index) for index in inds(tensor))) for
tensor in [tensor, c_tensor]
])

# Check if contraction does not increase the rank
EinExprs.removedsize(path) < 0 && continue

new_tensor = contract(tensor, c_tensor)

# Update tensor network
push!(tn, new_tensor)
delete!(tn, tensor)
delete!(tn, c_tensor)

# Break the loop since we modified the network and need to recheck connections
@goto rank_transformation_start
end
end

return tn
end

"""
AntiDiagonalGauging <: Transformation
Expand Down

0 comments on commit ceae9fa

Please sign in to comment.