Skip to content

Commit

Permalink
Refactor abstract methods to TensorNetwork
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jan 21, 2024
1 parent 87be056 commit cb9bfd7
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 65 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ makedocs(
"Visualization"=>"visualization.md",
"Alternatives"=>"alternatives.md",
"References"=>"references.md",
"Developer Notes"=>Any["`AbstractTensorNetwork` interface"=>"interface.md"],
],
format = Documenter.HTML(
prettyurls = false,
Expand Down
7 changes: 7 additions & 0 deletions docs/src/developer/interface.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# `AbstractTensorNetwork` interface

| **Methods to implement** | **Brief description** |
| ------------------------ | --------------------- |
| `tensors(tn)` | |
| `inds(tn)` | |
| `size(tn)` | |
134 changes: 69 additions & 65 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,32 @@ TensorNetwork() = TensorNetwork(Tensor[])
Return a shallow copy of a [`TensorNetwork`](@ref).
"""
Base.copy(tn::T) where {T<:AbstractTensorNetwork} = TensorNetwork(tensors(tn))
Base.copy(tn::TensorNetwork) = TensorNetwork(tensors(tn))

Base.summary(io::IO, tn::AbstractTensorNetwork) = print(io, "$(length(tn.tensormap))-tensors $(typeof(tn))")
Base.show(io::IO, tn::AbstractTensorNetwork) =
Base.summary(io::IO, tn::T) where {T<:AbstractTensorNetwork} = print(io, "$(length(tensors(tn)))-tensors $(typeof(tn))")
Base.summary(io::IO, tn::TensorNetwork) = print(io, "$(length(tn.tensormap))-tensors TensorNetwork")

Base.show(io::IO, tn::T) where {T<:AbstractTensorNetwork} =
print(io, "$T(#tensors=$(length(tensors(tn))), #inds=$(length(inds(tn))))")
Base.show(io::IO, tn::TensorNetwork) =
print(io, "$(typeof(tn))(#tensors=$(length(tn.tensormap)), #inds=$(length(tn.indexmap)))")

"""
tensors(tn::AbstractTensorNetwork)
tensors(tn::TensorNetwork)
Return a list of the `Tensor`s in the [`TensorNetwork`](@ref).
# Implementation details
- As the tensors of a [`TensorNetwork`](@ref) are stored as keys of the `.tensormap` dictionary and it uses `objectid` as hash, order is not stable so it sorts for repeated evaluations.
"""
tensors(tn::AbstractTensorNetwork) = sort!(collect(keys(tn.tensormap)), by = inds)
tensors(tn::TensorNetwork) = sort!(collect(keys(tn.tensormap)), by = inds)
arrays(tn::AbstractTensorNetwork) = parent.(tensors(tn))

Base.collect(tn::AbstractTensorNetwork) = tensors(tn)

"""
inds(tn::AbstractTensorNetwork, set = :all)
inds(tn::TensorNetwork, set = :all)
Return the names of the indices in the [`TensorNetwork`](@ref).
Expand All @@ -73,46 +77,46 @@ Return the names of the indices in the [`TensorNetwork`](@ref).
+ `:inner` Indices mentioned at least twice.
+ `:hyper` Indices mentioned at least in three tensors.
"""
Tenet.inds(tn::AbstractTensorNetwork; set::Symbol = :all, kwargs...) = inds(tn, set; kwargs...)
@valsplit 2 Tenet.inds(tn::AbstractTensorNetwork, set::Symbol, args...) = throw(MethodError(inds, "unknown set=$set"))
Tenet.inds(tn::TensorNetwork; set::Symbol = :all, kwargs...) = inds(tn, set; kwargs...)
@valsplit 2 Tenet.inds(tn::TensorNetwork, set::Symbol, args...) = throw(MethodError(inds, "unknown set=$set"))

function Tenet.inds(tn::AbstractTensorNetwork, ::Val{:all})
function Tenet.inds(tn::TensorNetwork, ::Val{:all})
collect(keys(tn.indexmap))
end

function Tenet.inds(tn::AbstractTensorNetwork, ::Val{:open})
function Tenet.inds(tn::TensorNetwork, ::Val{:open})
map(first, Iterators.filter(((_, v),) -> length(v) == 1, tn.indexmap))
end

function Tenet.inds(tn::AbstractTensorNetwork, ::Val{:inner})
function Tenet.inds(tn::TensorNetwork, ::Val{:inner})
map(first, Iterators.filter(((_, v),) -> length(v) >= 2, tn.indexmap))
end

function Tenet.inds(tn::AbstractTensorNetwork, ::Val{:hyper})
function Tenet.inds(tn::TensorNetwork, ::Val{:hyper})
map(first, Iterators.filter(((_, v),) -> length(v) >= 3, tn.indexmap))
end

"""
size(tn::AbstractTensorNetwork)
size(tn::AbstractTensorNetwork, index)
size(tn::TensorNetwork)
size(tn::TensorNetwork, index)
Return a mapping from indices to their dimensionalities.
If `index` is set, return the dimensionality of `index`. This is equivalent to `size(tn)[index]`.
"""
Base.size(tn::AbstractTensorNetwork) = Dict{Symbol,Int}(index => size(tn, index) for index in keys(tn.indexmap))
Base.size(tn::AbstractTensorNetwork, index::Symbol) = size(first(tn.indexmap[index]), index)
Base.size(tn::TensorNetwork) = Dict{Symbol,Int}(index => size(tn, index) for index in keys(tn.indexmap))
Base.size(tn::TensorNetwork, index::Symbol) = size(first(tn.indexmap[index]), index)

Base.eltype(tn::AbstractTensorNetwork) = promote_type(eltype.(tensors(tn))...)

"""
push!(tn::AbstractTensorNetwork, tensor::Tensor)
push!(tn::TensorNetwork, tensor::Tensor)
Add a new `tensor` to the Tensor Network.
See also: [`append!`](@ref), [`pop!`](@ref).
"""
function Base.push!(tn::AbstractTensorNetwork, tensor::Tensor)
function Base.push!(tn::TensorNetwork, tensor::Tensor)
tensor keys(tn.tensormap) && return tn

# check index sizes
Expand All @@ -129,39 +133,39 @@ function Base.push!(tn::AbstractTensorNetwork, tensor::Tensor)
end

"""
append!(tn::AbstractTensorNetwork, tensors::AbstractVecOrTuple{<:Tensor})
append!(tn::TensorNetwork, tensors::AbstractVecOrTuple{<:Tensor})
Add a list of tensors to a `TensorNetwork`.
See also: [`push!`](@ref), [`merge!`](@ref).
"""
Base.append!(tn::AbstractTensorNetwork, tensors) = (foreach(Base.Fix1(push!, tn), tensors); tn)
Base.append!(tn::TensorNetwork, tensors) = (foreach(Base.Fix1(push!, tn), tensors); tn)

"""
merge!(self::AbstractTensorNetwork, others::AbstractTensorNetwork...)
merge(self::AbstractTensorNetwork, others::AbstractTensorNetwork...)
merge!(self::TensorNetwork, others::TensorNetwork...)
merge(self::TensorNetwork, others::TensorNetwork...)
Fuse various [`TensorNetwork`](@ref)s into one.
See also: [`append!`](@ref).
"""
Base.merge!(self::AbstractTensorNetwork, other::AbstractTensorNetwork) = append!(self, tensors(other))
Base.merge!(self::AbstractTensorNetwork, others::AbstractTensorNetwork...) = foldl(merge!, others; init = self)
Base.merge(self::AbstractTensorNetwork, others::AbstractTensorNetwork...) = merge!(copy(self), others...)
Base.merge!(self::TensorNetwork, other::TensorNetwork) = append!(self, tensors(other))
Base.merge!(self::TensorNetwork, others::TensorNetwork...) = foldl(merge!, others; init = self)
Base.merge(self::TensorNetwork, others::TensorNetwork...) = merge!(copy(self), others...)

"""
pop!(tn::AbstractTensorNetwork, tensor::Tensor)
pop!(tn::AbstractTensorNetwork, i::Union{Symbol,AbstractVecOrTuple{Symbol}})
pop!(tn::TensorNetwork, tensor::Tensor)
pop!(tn::TensorNetwork, i::Union{Symbol,AbstractVecOrTuple{Symbol}})
Remove a tensor from the Tensor Network and returns it. If a `Tensor` is passed, then the first tensor satisfies _egality_ (i.e. `≡` or `===`) will be removed.
If a `Symbol` or a list of `Symbol`s is passed, then remove and return the tensors that contain all the indices.
See also: [`push!`](@ref), [`delete!`](@ref).
"""
Base.pop!(tn::AbstractTensorNetwork, tensor::Tensor) = (delete!(tn, tensor); tensor)
Base.pop!(tn::AbstractTensorNetwork, i::Symbol) = pop!(tn, (i,))
Base.pop!(tn::TensorNetwork, tensor::Tensor) = (delete!(tn, tensor); tensor)
Base.pop!(tn::TensorNetwork, i::Symbol) = pop!(tn, (i,))

function Base.pop!(tn::AbstractTensorNetwork, i::AbstractVecOrTuple{Symbol})::Vector{Tensor}
function Base.pop!(tn::TensorNetwork, i::AbstractVecOrTuple{Symbol})::Vector{Tensor}
tensors = select(tn, i)
for tensor in tensors
_ = pop!(tn, tensor)
Expand All @@ -171,15 +175,15 @@ function Base.pop!(tn::AbstractTensorNetwork, i::AbstractVecOrTuple{Symbol})::Ve
end

"""
delete!(tn::AbstractTensorNetwork, x)
delete!(tn::TensorNetwork, x)
Like [`pop!`](@ref) but return the [`TensorNetwork`](@ref) instead.
"""
Base.delete!(tn::AbstractTensorNetwork, x) = (_ = pop!(tn, x); tn)
Base.delete!(tn::TensorNetwork, x) = (_ = pop!(tn, x); tn)

tryprune!(tn::AbstractTensorNetwork, i::Symbol) = (x = isempty(tn.indexmap[i]) && delete!(tn.indexmap, i); x)
tryprune!(tn::TensorNetwork, i::Symbol) = (x = isempty(tn.indexmap[i]) && delete!(tn.indexmap, i); x)

function Base.delete!(tn::AbstractTensorNetwork, tensor::Tensor)
function Base.delete!(tn::TensorNetwork, tensor::Tensor)
for index in unique(inds(tensor))
filter!(Base.Fix1(!==, tensor), tn.indexmap[index])
tryprune!(tn, index)
Expand All @@ -190,25 +194,25 @@ function Base.delete!(tn::AbstractTensorNetwork, tensor::Tensor)
end

"""
replace!(tn::AbstractTensorNetwork, old => new...)
replace(tn::AbstractTensorNetwork, old => new...)
replace!(tn::TensorNetwork, old => new...)
replace(tn::TensorNetwork, old => new...)
Replace the element in `old` with the one in `new`. Depending on the types of `old` and `new`, the following behaviour is expected:
- If `Symbol`s, it will correspond to a index renaming.
- If `Tensor`s, first element that satisfies _egality_ (`≡` or `===`) will be replaced.
"""
Base.replace!(tn::AbstractTensorNetwork, old_new::Pair...) = replace!(tn, old_new)
function Base.replace!(tn::AbstractTensorNetwork, old_new::Base.AbstractVecOrTuple{Pair})
Base.replace!(tn::TensorNetwork, old_new::Pair...) = replace!(tn, old_new)
function Base.replace!(tn::TensorNetwork, old_new::Base.AbstractVecOrTuple{Pair})
for pair in old_new
replace!(tn, pair)
end
return tn
end
Base.replace(tn::AbstractTensorNetwork, old_new::Pair...) = replace(tn, old_new)
Base.replace(tn::AbstractTensorNetwork, old_new) = replace!(copy(tn), old_new)
Base.replace(tn::TensorNetwork, old_new::Pair...) = replace(tn, old_new)
Base.replace(tn::TensorNetwork, old_new) = replace!(copy(tn), old_new)

function Base.replace!(tn::AbstractTensorNetwork, pair::Pair{<:Tensor,<:Tensor})
function Base.replace!(tn::TensorNetwork, pair::Pair{<:Tensor,<:Tensor})
old_tensor, new_tensor = pair
issetequal(inds(new_tensor), inds(old_tensor)) || throw(ArgumentError("replacing tensor indices don't match"))

Expand All @@ -218,7 +222,7 @@ function Base.replace!(tn::AbstractTensorNetwork, pair::Pair{<:Tensor,<:Tensor})
return tn
end

function Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{Symbol,Symbol}...)
function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol}...)
first.(old_new) keys(tn.indexmap) ||
throw(ArgumentError("set of old indices must be a subset of current indices"))
isdisjoint(last.(old_new), keys(tn.indexmap)) ||
Expand All @@ -229,7 +233,7 @@ function Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{Symbol,Symbol}..
return tn
end

function Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{Symbol,Symbol})
function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol})
old, new = old_new
old keys(tn.indexmap) || throw(ArgumentError("index $old does not exist"))
new keys(tn.indexmap) || throw(ArgumentError("index $new is already present"))
Expand All @@ -246,7 +250,7 @@ function Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{Symbol,Symbol})
return tn
end

function Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{<:Tensor,<:AbstractTensorNetwork})
function Base.replace!(tn::TensorNetwork, old_new::Pair{<:Tensor,<:TensorNetwork})
old, new = old_new
issetequal(inds(new, set = :open), inds(old)) || throw(ArgumentError("indices don't match match"))

Expand All @@ -260,12 +264,12 @@ function Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{<:Tensor,<:Abstr
end

"""
select(tn::AbstractTensorNetwork, i)
select(tn::TensorNetwork, i)
Return tensors whose indices match with the list of indices `i`.
"""
select(tn::AbstractTensorNetwork, i::Symbol) = copy(tn.indexmap[i])
select(tn::AbstractTensorNetwork, is::AbstractVecOrTuple{Symbol}) = select(, tn, is)
select(tn::TensorNetwork, i::Symbol) = copy(tn.indexmap[i])
select(tn::TensorNetwork, is::AbstractVecOrTuple{Symbol}) = select(, tn, is)

function select(selector, tn::TensorNetwork, is::AbstractVecOrTuple{Symbol})
filter(Base.Fix1(selector, is) inds, tn.indexmap[first(is)])
Expand All @@ -276,23 +280,23 @@ function Base.getindex(tn::TensorNetwork, is::Symbol...; mul::Int = 1)
end

"""
in(tensor::Tensor, tn::AbstractTensorNetwork)
in(index::Symbol, tn::AbstractTensorNetwork)
in(tensor::Tensor, tn::TensorNetwork)
in(index::Symbol, tn::TensorNetwork)
Return `true` if there is a `Tensor` in `tn` for which `==` evaluates to `true`.
This method is equivalent to `tensor ∈ tensors(tn)` code, but it's faster on large amount of tensors.
"""
Base.in(tensor::Tensor, tn::AbstractTensorNetwork) = tensor keys(tn.tensormap)
Base.in(index::Symbol, tn::AbstractTensorNetwork) = index keys(tn.indexmap)
Base.in(tensor::Tensor, tn::TensorNetwork) = tensor keys(tn.tensormap)
Base.in(index::Symbol, tn::TensorNetwork) = index keys(tn.indexmap)

"""
slice!(tn::AbstractTensorNetwork, index::Symbol, i)
slice!(tn::TensorNetwork, index::Symbol, i)
In-place projection of `index` on dimension `i`.
See also: [`selectdim`](@ref), [`view`](@ref).
"""
function slice!(tn::AbstractTensorNetwork, label::Symbol, i)
function slice!(tn::TensorNetwork, label::Symbol, i)
for tensor in pop!(tn, label)
push!(tn, selectdim(tensor, label, i))
end
Expand All @@ -301,23 +305,23 @@ function slice!(tn::AbstractTensorNetwork, label::Symbol, i)
end

"""
selectdim(tn::AbstractTensorNetwork, index::Symbol, i)
selectdim(tn::TensorNetwork, index::Symbol, i)
Return a copy of the [`TensorNetwork`](@ref) where `index` has been projected to dimension `i`.
See also: [`view`](@ref), [`slice!`](@ref).
"""
Base.selectdim(tn::AbstractTensorNetwork, label::Symbol, i) = @view tn[label=>i]
Base.selectdim(tn::TensorNetwork, label::Symbol, i) = @view tn[label=>i]

"""
view(tn::AbstractTensorNetwork, index => i...)
view(tn::TensorNetwork, index => i...)
Return a copy of the [`TensorNetwork`](@ref) where each `index` has been projected to dimension `i`.
It is equivalent to a recursive call of [`selectdim`](@ref).
See also: [`selectdim`](@ref), [`slice!`](@ref).
"""
function Base.view(tn::AbstractTensorNetwork, slices::Pair{Symbol}...)
function Base.view(tn::TensorNetwork, slices::Pair{Symbol}...)
tn = copy(tn)

for (label, i) in slices
Expand Down Expand Up @@ -409,13 +413,13 @@ EinExprs.einexpr(tn::AbstractTensorNetwork; optimizer = Greedy, outputs = inds(t
# TODO sequence of indices?
# TODO what if parallel neighbour indices?
"""
contract!(tn::AbstractTensorNetwork, index)
contract!(tn::TensorNetwork, index)
In-place contraction of tensors connected to `index`.
See also: [`contract`](@ref).
"""
function contract!(tn::AbstractTensorNetwork, i)
function contract!(tn::TensorNetwork, i)
tensor = reduce(pop!(tn, i)) do acc, tensor
contract(acc, tensor, dims = i)
end
Expand All @@ -425,25 +429,25 @@ function contract!(tn::AbstractTensorNetwork, i)
end

"""
contract(tn::AbstractTensorNetwork; kwargs...)
contract(tn::TensorNetwork; kwargs...)
Contract a [`TensorNetwork`](@ref). The contraction order will be first computed by [`einexpr`](@ref).
The `kwargs` will be passed down to the [`einexpr`](@ref) function.
See also: [`einexpr`](@ref), [`contract!`](@ref).
"""
function contract(tn::AbstractTensorNetwork; path = einexpr(tn))
function contract(tn::TensorNetwork; path = einexpr(tn))
length(path.args) == 0 && return tn[inds(path)...]

intermediates = map(subpath -> contract(tn; path = subpath), path.args)
contract(intermediates...; dims = suminds(path))
end

contract!(t::Tensor, tn::AbstractTensorNetwork; kwargs...) = contract!(tn, t; kwargs...)
contract!(tn::AbstractTensorNetwork, t::Tensor; kwargs...) = (push!(tn, t); contract(tn; kwargs...))
contract(t::Tensor, tn::AbstractTensorNetwork; kwargs...) = contract(tn, t; kwargs...)
contract(tn::AbstractTensorNetwork, t::Tensor; kwargs...) = contract!(copy(tn), t; kwargs...)
contract!(t::Tensor, tn::TensorNetwork; kwargs...) = contract!(tn, t; kwargs...)
contract!(tn::TensorNetwork, t::Tensor; kwargs...) = (push!(tn, t); contract(tn; kwargs...))
contract(t::Tensor, tn::TensorNetwork; kwargs...) = contract(tn, t; kwargs...)
contract(tn::TensorNetwork, t::Tensor; kwargs...) = contract!(copy(tn), t; kwargs...)

struct TNSampler{T<:AbstractTensorNetwork} <: Random.Sampler{T}
config::Dict{Symbol,Any}
Expand Down

0 comments on commit cb9bfd7

Please sign in to comment.