Skip to content

Commit

Permalink
Add lu decomposition for Tensors (#94)
Browse files Browse the repository at this point in the history
* Implement LinearAlgebra.lu decompoisition

* Add tests for the new LinearAlgebra.lu function

* Replace legacy labels for inds function

* Refactor LU decomposition

* Fix undef var in `factorinds`

* Refactor QR decomposition

* Update docstrings of `qr`,`lu`

* Refactor SVD factorization

* Fix typo

* Add factorizations to docs

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
jofrevalles and mofeing authored Nov 13, 2023
1 parent 13742fb commit dd9b6d3
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 149 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Muscle = "21fe5c4b-a943-414d-bf3e-516f24900631"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
ValSplit = "0625e100-946b-11ec-09cd-6328dd093154"

Expand Down
8 changes: 8 additions & 0 deletions docs/src/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,11 @@ length(Tᵢⱼₖ)
```@docs
Tenet.contract(::Tensor, ::Tensor)
```

### Factorizations

```@docs
LinearAlgebra.svd(::Tensor)
LinearAlgebra.qr(::Tensor)
LinearAlgebra.lu(::Tensor)
```
160 changes: 107 additions & 53 deletions src/Numerics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using OMEinsum
using LinearAlgebra
using UUIDs: uuid4
using SparseArrays

# TODO test array container typevar on output
for op in [
Expand Down Expand Up @@ -79,89 +80,142 @@ Base.:*(a::Tensor, b::Tensor) = contract(a, b)
Base.:*(a::T, b::Number) where {T<:Tensor} = T(parent(a) * b, inds(a))
Base.:*(a::Number, b::T) where {T<:Tensor} = T(a * parent(b), inds(b))

function factorinds(tensor, left_inds, right_inds)
isdisjoint(left_inds, right_inds) ||
throw(ArgumentError("left ($left_inds) and right $(right_inds) indices must be disjoint"))

left_inds, right_inds =
isempty(left_inds) ? (setdiff(inds(tensor), right_inds), right_inds) :
isempty(right_inds) ? (left_inds, setdiff(inds(tensor), left_inds)) :
throw(ArgumentError("cannot set both left and right indices"))

all(!isempty, (left_inds, right_inds)) || throw(ArgumentError("no right-indices left in factorization"))
all((inds(tensor)), left_inds right_inds) || throw(ArgumentError("indices must be in $(inds(tensor))"))

return left_inds, right_inds
end

LinearAlgebra.svd(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke svd(t::Tensor; left_inds = (first(inds(t)),), kwargs...)

function LinearAlgebra.svd(t::Tensor; left_inds, kwargs...)
if isempty(left_inds)
throw(ErrorException("no left-indices in SVD factorization"))
elseif any((inds(t)), left_inds)
# TODO better error exception and checks
throw(ErrorException("all left-indices must be in $(inds(t))"))
end
"""
LinearAlgebra.svd(tensor::Tensor; left_inds, right_inds, virtualind, kwargs...)
Perform SVD factorization on a tensor.
# Keyword arguments
right_inds = setdiff(inds(t), left_inds)
if isempty(right_inds)
# TODO better error exception and checks
throw(ErrorException("no right-indices in SVD factorization"))
end
- `left_inds`: left indices to be used in the SVD factorization. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the SVD factorization. Defaults to all indices of `t` except `left_inds`.
- `virtualind`: name of the virtual bond. Defaults to a random `Symbol`.
"""
function LinearAlgebra.svd(tensor::Tensor; left_inds = (), right_inds = (), virtualind = Symbol(uuid4()), kwargs...)
left_inds, right_inds = factorinds(tensor, left_inds, right_inds)

virtualind inds(tensor) ||
throw(ArgumentError("new virtual bond name ($virtualind) cannot be already be present"))

# permute array
tensor = permutedims(t, (left_inds..., right_inds...))
data = reshape(parent(tensor), prod(i -> size(t, i), left_inds), prod(i -> size(t, i), right_inds))
left_sizes = map(Base.Fix1(size, tensor), left_inds)
right_sizes = map(Base.Fix1(size, tensor), right_inds)
tensor = permutedims(tensor, [left_inds..., right_inds...])
data = reshape(parent(tensor), prod(left_sizes), prod(right_sizes))

# compute SVD
U, s, V = svd(data; kwargs...)

# tensorify results
U = reshape(U, ([size(t, ind) for ind in left_inds]..., size(U, 2)))
s = Diagonal(s)
Vt = reshape(V', (size(V', 1), [size(t, ind) for ind in right_inds]...))

vlind = Symbol(uuid4())
vrind = Symbol(uuid4())

U = Tensor(U, (left_inds..., vlind))
s = Tensor(s, (vlind, vrind))
Vt = Tensor(Vt, (vrind, right_inds...))
U = Tensor(reshape(U, left_sizes..., size(U, 2)), [left_inds..., virtualind])
s = Tensor(s, [virtualind])
Vt = Tensor(reshape(V, right_sizes..., size(V, 2)), [right_inds..., virtualind])

return U, s, Vt
end

LinearAlgebra.qr(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke qr(t::Tensor; left_inds = (first(inds(t)),), kwargs...)

"""
LinearAlgebra.qr(t::Tensor, mode::Symbol = :reduced; left_inds = (), right_inds = (), virtualind::Symbol = Symbol(uuid4()), kwargs...
LinearAlgebra.qr(tensor::Tensor; left_inds, right_inds, virtualind, kwargs...)
Perform QR factorization on a tensor.
# Arguments
- `t::Tensor`: tensor to be factorized
# Keyword Arguments
# Keyword arguments
- `left_inds`: left indices to be used in the QR factorization. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the QR factorization. Defaults to all indices of `t` except `left_inds`.
- `virtualind`: name of the virtual bond. Defaults to a random `Symbol`.
- `left_inds`: left indices to be used in the QR factorization. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the QR factorization. Defaults to all indices of `t` except `left_inds`.
- `virtualind`: name of the virtual bond. Defaults to a random `Symbol`.
"""
function LinearAlgebra.qr(t::Tensor; left_inds = (), right_inds = (), virtualind::Symbol = Symbol(uuid4()), kwargs...)
isdisjoint(left_inds, right_inds) ||
throw(ArgumentError("left ($left_inds) and right $(right_inds) indices must be disjoint"))

left_inds, right_inds =
isempty(left_inds) ? (setdiff(inds(t), right_inds), right_inds) :
isempty(right_inds) ? (left_inds, setdiff(inds(t), left_inds)) :
throw(ArgumentError("cannot set both left and right indices"))

all(!isempty, (left_inds, right_inds)) || throw(ArgumentError("no right-indices left in QR factorization"))
all((inds(t)), left_inds right_inds) || throw(ArgumentError("indices must be in $(inds(t))"))

virtualind inds(t) || throw(ArgumentError("new virtual bond name ($virtualind) cannot be already be present"))
function LinearAlgebra.qr(
tensor::Tensor;
left_inds = (),
right_inds = (),
virtualind::Symbol = Symbol(uuid4()),
kwargs...,
)
left_inds, right_inds = factorinds(tensor, left_inds, right_inds)

virtualind inds(tensor) ||
throw(ArgumentError("new virtual bond name ($virtualind) cannot be already be present"))

# permute array
tensor = permutedims(t, (left_inds..., right_inds...))
data = reshape(parent(tensor), prod(i -> size(t, i), left_inds), prod(i -> size(t, i), right_inds))
left_sizes = map(Base.Fix1(size, tensor), left_inds)
right_sizes = map(Base.Fix1(size, tensor), right_inds)
tensor = permutedims(tensor, [left_inds..., right_inds...])
data = reshape(parent(tensor), prod(left_sizes), prod(right_sizes))

# compute QR
F = qr(data; kwargs...)
Q, R = Matrix(F.Q), Matrix(F.R)

# tensorify results
Q = reshape(Q, ([size(t, ind) for ind in left_inds]..., size(Q, 2)))
R = reshape(R, (size(R, 1), [size(t, ind) for ind in right_inds]...))

Q = Tensor(Q, (left_inds..., virtualind))
R = Tensor(R, (virtualind, right_inds...))
Q = Tensor(reshape(Q, left_sizes..., size(Q, 2)), [left_inds..., virtualind])
R = Tensor(reshape(R, size(R, 1), right_sizes...), [virtualind, right_inds...])

return Q, R
end

LinearAlgebra.lu(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke lu(t::Tensor; left_inds = (first(inds(t)),), kwargs...)

"""
LinearAlgebra.lu(tensor::Tensor; left_inds, right_inds, virtualind, kwargs...)
Perform LU factorization on a tensor.
# Keyword arguments
- `left_inds`: left indices to be used in the LU factorization. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the LU factorization. Defaults to all indices of `t` except `left_inds`.
- `virtualind`: name of the virtual bond. Defaults to a random `Symbol`.
"""
function LinearAlgebra.lu(
tensor::Tensor;
left_inds = (),
right_inds = (),
virtualind = [Symbol(uuid4()), Symbol(uuid4())],
kwargs...,
)
left_inds, right_inds = factorinds(tensor, left_inds, right_inds)

i_pl, i_lu = virtualind
i_pl inds(tensor) || throw(ArgumentError("new virtual bond name ($i_pl) cannot be already be present"))
i_lu inds(tensor) || throw(ArgumentError("new virtual bond name ($i_lu) cannot be already be present"))

# permute array
left_sizes = map(Base.Fix1(size, tensor), left_inds)
right_sizes = map(Base.Fix1(size, tensor), right_inds)
tensor = permutedims(tensor, [left_inds..., right_inds...])
data = reshape(parent(tensor), prod(left_sizes), prod(right_sizes))

# compute LU
info = lu(data; kwargs...)
L = info.L
U = info.U

permutator = info.p
P = sparse(permutator, 1:length(permutator), fill(true, length(permutator)))

L = Tensor(L, [i_pl, i_lu])
U = Tensor(reshape(U, size(U, 1), right_sizes...), [i_lu, right_inds...])
P = Tensor(reshape(P, left_sizes..., size(L, 1)), [left_inds..., i_pl])

return L, U, P
end
2 changes: 1 addition & 1 deletion src/Tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ Base.selectdim(t::Tensor, d::Symbol, i) = selectdim(t, dim(t, d), i)
Base.permutedims(t::Tensor, perm) = Tensor(permutedims(parent(t), perm), getindex.((inds(t),), perm))
Base.permutedims!(dest::Tensor, src::Tensor, perm) = permutedims!(parent(dest), parent(src), perm)

function Base.permutedims(t::Tensor{T,N}, perm::NTuple{N,Symbol}) where {T,N}
function Base.permutedims(t::Tensor{T}, perm::Base.AbstractVecOrTuple{Symbol}) where {T}
perm = map(i -> findfirst(==(i), inds(t)), perm)
permutedims(t, perm)
end
Expand Down
15 changes: 8 additions & 7 deletions src/Transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,20 +249,21 @@ function transform!(tn::AbstractTensorNetwork, config::SplitSimplification)
bipartitions = Iterators.flatten(combinations(inds, r) for r in 1:(length(inds)-1))
for bipartition in bipartitions
left_inds = collect(bipartition)
right_inds = setdiff(inds, left_inds)

# perform an SVD across the bipartition
u, s, v = svd(tensor; left_inds = left_inds)
rank_s = sum(diag(s) .> config.atol)
rank_s = sum(s .> config.atol)

if rank_s < length(s)
hyperindex = only(Tenet.inds(s))

if rank_s < size(s, 1)
# truncate data
u = view(u, Tenet.inds(s)[1] => 1:rank_s)
s = view(s, (idx -> idx => 1:rank_s).(Tenet.inds(s))...)
v = view(v, Tenet.inds(s)[2] => 1:rank_s)
u = view(u, hyperindex => 1:rank_s)
s = view(s, hyperindex => 1:rank_s)
v = view(v, hyperindex => 1:rank_s)

# replace the original tensor with factorization
tensor_l = u * s
tensor_l = contract(u, s, dims = Symbol[])
tensor_r = v

push!(tn, dropdims(tensor_l))
Expand Down
Loading

0 comments on commit dd9b6d3

Please sign in to comment.