Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Form-specific constructors for MPS #248

Merged
merged 9 commits into from
Nov 15, 2024
3 changes: 3 additions & 0 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ struct NonCanonical <: Form end
MixedCanonical

[`Form`](@ref) trait representing a [`AbstractAnsatz`](@ref) Tensor Network in mixed-canonical form.

- The orthogonality center is a [`Site`](@ref) or a vector of [`Site`](@ref)s. The tensors to the
left of the orthogonality center are left-canonical and the tensors to the right are right-canonical.
"""
struct MixedCanonical <: Form
orthog_center::Union{Site,Vector{Site}}
Expand Down
150 changes: 148 additions & 2 deletions src/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@ Base.zero(x::T) where {T<:Union{MPS,MPO}} = T(zero(Ansatz(x)), form(x))
defaultorder(::Type{<:AbstractMPS}) = (:o, :l, :r)
defaultorder(::Type{<:AbstractMPO}) = (:o, :i, :l, :r)

function MPS(
arrays::Vector{<:AbstractArray}; order=defaultorder(MPS), form::Form=NonCanonical(), check_canonical_form=true
)
return MPS(form, arrays; order=order, check_canonical_form=check_canonical_form)
end
function MPS(
arrays::Vector{<:AbstractArray},
λ::Vector{<:AbstractArray};
order=defaultorder(MPS),
form::Form=Canonical(),
check_canonical_form=true,
)
return MPS(form, arrays, λ; order=order, check_canonical_form=check_canonical_form)
end

jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
"""
MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS))

Expand All @@ -47,7 +62,7 @@ Create a [`NonCanonical`](@ref) [`MPS`](@ref) from a vector of arrays.

- `order` The order of the indices in the arrays. Defaults to `(:o, :l, :r)`.
"""
function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS))
function MPS(::NonCanonical, arrays; order=defaultorder(MPS), check_canonical_form=true)
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions"
@assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions"
@assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"
Expand Down Expand Up @@ -92,6 +107,136 @@ function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS))
return MPS(ansatz, NonCanonical())
end

function MPS(form::MixedCanonical, arrays; order=defaultorder(MPS), check_canonical_form=true)
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions"
@assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions"
@assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"
issetequal(order, defaultorder(MPS)) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))"))

n = length(arrays)
gen = IndexCounter()
symbols = [nextindex!(gen) for _ in 1:(2n)]

tn = TensorNetwork(
map(enumerate(arrays)) do (i, array)
_order = if i == 1
filter(x -> x != :l, order)
elseif i == n
filter(x -> x != :r, order)
else
order
end

inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :r
symbols[n + mod1(i, n)]
elseif dir == :l
symbols[n + mod1(i - 1, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end,
)

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
qtn = Quantum(tn, sitemap)
graph = path_graph(n)
mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))])
lattice = Lattice(mapping, graph)
ansatz = Ansatz(qtn, lattice)
mps = MPS(ansatz, form)

# Check that for site start to orthog_center-1 the tensors are left-canonical
if check_canonical_form
for i in 1:(id(form.orthog_center) - 1)
isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical"))
end

# Check that for site orthog_center+1 to end the tensors are right-canonical
for i in (id(form.orthog_center) + 1):nsites(mps)
isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical"))
end
end

return mps
end

function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check_canonical_form=true)
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions"
@assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions"
@assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"

@assert length(λ) == length(arrays) - 1 "Number of λ tensors must be one less than the number of arrays"
@assert all(==(1) ∘ ndims, λ) "All λ tensors must be Vectors"

issetequal(order, defaultorder(MPS)) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))"))

n = length(arrays)
gen = IndexCounter()
symbols = [nextindex!(gen) for _ in 1:(2n)]

# Create tensors from 'arrays'
tensor_list = map(enumerate(arrays)) do (i, array)
_order = if i == 1
filter(x -> x != :l, order)
elseif i == n
filter(x -> x != :r, order)
else
order
end

inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :r
symbols[n + mod1(i, n)]
elseif dir == :l
symbols[n + mod1(i - 1, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end

# Create tensors from 'λ'
lambda_tensors = map(enumerate(λ)) do (i, array)
Tensor(array, [symbols[n + mod1(i, n)]])
end

tn = TensorNetwork(vcat(tensor_list, lambda_tensors))
sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
qtn = Quantum(tn, sitemap)
graph = path_graph(n)
mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))])
lattice = Lattice(mapping, graph)
ansatz = Ansatz(qtn, lattice)
mps = MPS(ansatz, Canonical())

# Check canonical form by contracting Γ and λ tensors and checking their orthogonality
if check_canonical_form
for i in 1:nsites(mps)
if i > 1
isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) ||
throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction."))
end

if i < nsites(mps)
isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) ||
throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction."))
end
end
end

return mps
end

"""
MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO))

Expand Down Expand Up @@ -238,7 +383,8 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float
arrays[1] = reshape(arrays[1], p, p)
arrays[n] = reshape(arrays[n], p, p)

return MPS(arrays; order=(:l, :o, :r))
return MPS(arrays; order=(:l, :o, :r), form=MixedCanonical(Site(0)))
# return MPS(arrays; order=(:l, :o, :r))
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
end

# TODO different input/output physical dims
Expand Down
Loading