diff --git a/docs/Project.toml b/docs/Project.toml index b2c25ef3..55846ef9 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,7 +11,7 @@ NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a" Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" [sources] -Tenet = {path = "/Users/mofeing/Developer/Tenet.jl/docs/.."} +Tenet = {path = ".."} [compat] Documenter = "1" diff --git a/docs/src/manual/ansatz/mps.md b/docs/src/manual/ansatz/mps.md index 91190139..532fa344 100644 --- a/docs/src/manual/ansatz/mps.md +++ b/docs/src/manual/ansatz/mps.md @@ -1,8 +1,9 @@ # Matrix Product States (MPS) -Matrix Product States (MPS) are a Quantum Tensor Network ansatz whose tensors are laid out in a 1D chain. -Due to this, these networks are also known as _Tensor Trains_ in other mathematical fields. -Depending on the boundary conditions, the chains can be open or closed (i.e. periodic boundary conditions). +Matrix Product States ([`MPS`](@ref)) are a Quantum Tensor Network ansatz whose tensors are laid out in a 1D chain. +Due to this, these networks are also known as _Tensor Trains_ in other scientific fields. +Depending on the boundary conditions, the chains can be open or closed (i.e. periodic boundary conditions), currently +only `Open` boundary conditions are supported in `Tenet`. ```@setup viz using Makie @@ -16,38 +17,102 @@ using NetworkLayout ``` ```@example viz -fig = Figure() # hide +fig = Figure() +open_mps = rand(MPS; n=10, maxdim=4) -tn_open = rand(MatrixProduct{State,Open}, n=10, χ=4) # hide -tn_periodic = rand(MatrixProduct{State,Periodic}, n=10, χ=4) # hide +plot!(fig[1,1], open_mps, layout=Spring(iterations=1000, C=0.5, seed=100)) +Label(fig[1,1, Bottom()], "Open") -plot!(fig[1,1], tn_open, layout=Spring(iterations=1000, C=0.5, seed=100)) # hide -plot!(fig[1,2], tn_periodic, layout=Spring(iterations=1000, C=0.5, seed=100)) # hide +fig +``` + +The default ordering of the indices on the `MPS` constructor is (physical, left, right), but you can specify the ordering by passing the `order` keyword argument: + +```@example +mps = MPS([rand(4, 2), rand(4, 8, 2), rand(8, 2)]; order=[:l, :r, :o]) +``` +where `:l`, `:r`, and `:o` represent the left, right, and outer physical indices, respectively. -Label(fig[1,1, Bottom()], "Open") # hide -Label(fig[1,2, Bottom()], "Periodic") # hide -fig # hide +### Canonical Forms + +An `MPS` representation is not unique: a single `MPS` can be represented in different canonical forms. The choice of canonical form can affect the efficiency and stability of algorithms used to manipulate the `MPS`. +The current form of the `MPS` is stored as the trait [`Form`](@ref) and can be accessed via the `form` function: + +```@example +mps = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + +form(mps) ``` +> :warning: Depending on the form, `Tenet` will dispatch under the hood the appropriate algorithm which assumes full use of the canonical form, so be careful when making modifications that might alter the canonical form without changing the trait. + +`Tenet` has the internal function [`Tenet.check_form`](@ref) to check if the `MPS` is in the correct canonical form. This function can be used to ensure that the `MPS` is in the correct form before performing any operation that requires it. +Currently, `Tenet` supports the [`NonCanonical`](@ref), [`CanonicalForm`](@ref) and [`MixedCanonical`](@ref) forms. + +#### `NonCanonical` Form +In the `NonCanonical` form, the tensors in the `MPS` do not satisfy any particular orthogonality conditions. This is the default `form` when an `MPS` is initialized without specifying a canonical form. It is useful for general purposes but may not be optimal for certain computations that benefit from orthogonality. + +#### `Canonical` Form +Also known as Vidal's form, the `Canonical` form represents the `MPS` using a sequence of isometric tensors (`Γ`) and diagonal vectors (`λ`) containing the Schmidt coefficients. The `MPS` is expressed as: + +```math +| \psi \rangle = \sum_{i_1, \dots, i_N} \Gamma_1^{i_1} \lambda_2 \Gamma_2^{i_2} \dots \lambda_{N-1} \Gamma_{N-1}^{i_{N-1}} \lambda_N \Gamma_N^{i_N} | i_1, \dots, i_N \rangle \, . +``` + +You can convert an `MPS` to the `Canonical` form by calling `canonize!`: + +```@example +mps = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) +canonize!(mps) + +form(mps) +``` + +#### `MixedCanonical` Form +In the `MixedCanonical` form, tensors to the left of the orthogonality center are left-canonical, tensors to the right are right-canonical, and the tensors at the orthogonality center (which can be `Site` or `Vector{<:Site}`) contains the entanglement information between the left and right parts of the chain. The position of the orthogonality center is stored in the `orthog_center` field. + +You can convert an `MPS` to the `MixedCanonical` form and specify the orthogonality center using `mixed_canonize!`. Additionally, one can check that the `MPS` is effectively in mixed canonical form using the functions `isleftcanonical` and `isrightcanonical`, which return `true` if the `Tensor` at that particular site is left or right canonical, respectively. + +```@example +mps = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) +mixed_canonize!(mps, Site(2)) + +isisometry(mps, 1; dir=:right) # Check if the first tensor is left canonical +isisometry(mps, 3; dir=:left) # Check if the third tensor is right canonical +``` + +form(mps) +``` + +##### Additional Resources +For more in-depth information on Matrix Product States and their canonical forms, you may refer to: +- Schollwöck, U. (2011). The density-matrix renormalization group in the age of matrix product states. Annals of physics, 326(1), 96-192. + ## Matrix Product Operators (MPO) -Matrix Product Operators (MPO) are the operator version of [Matrix Product State (MPS)](#matrix-product-states-mps). -The major difference between them is that MPOs have 2 indices per site (1 input and 1 output) while MPSs only have 1 index per site (i.e. an output). +Matrix Product Operators ([`MPO`](@ref)) are the operator version of [Matrix Product State (MPS)](#matrix-product-states-mps). +The major difference between them is that MPOs have 2 indices per site (1 input and 1 output) while MPSs only have 1 index per site (i.e. an output). Currently, only `Open` boundary conditions are supported in `Tenet`. ```@example viz -fig = Figure() # hide +fig = Figure() +open_mpo = rand(MPO, n=10, maxdim=4) -tn_open = rand(MatrixProduct{Operator,Open}, n=10, χ=4) # hide -tn_periodic = rand(MatrixProduct{Operator,Periodic}, n=10, χ=4) # hide +plot!(fig[1,1], open_mpo, layout=Spring(iterations=1000, C=0.5, seed=100)) +Label(fig[1,1, Bottom()], "Open") -plot!(fig[1,1], tn_open, layout=Spring(iterations=1000, C=0.5, seed=100)) # hide -plot!(fig[1,2], tn_periodic, layout=Spring(iterations=1000, C=0.5, seed=100)) # hide +fig +``` -Label(fig[1,1, Bottom()], "Open") # hide -Label(fig[1,2, Bottom()], "Periodic") # hide +To apply an `MPO` to an `MPS`, you can use the `evolve!` function: -fig # hide -``` +```@example +mps = rand(MPS; n=10, maxdim=100) +mpo = rand(MPO; n=10, maxdim=4) + +size.(tensors(mps)) + +evolve!(mps, mpo) -In `Tenet`, the generic `MatrixProduct` ansatz implements this topology. Type variables are used to address their functionality (`State` or `Operator`) and their boundary conditions (`Open` or `Periodic`). +size.(tensors(mps)) +``` \ No newline at end of file diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 8cdca599..a972eaf5 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -13,7 +13,7 @@ function Reactant.make_tracer( seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs... ) where {RT<:Tensor} tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...) - return Tensor(tracedata, inds(prev)) + return Tensor(tracedata, copy(inds(prev))) end function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...) @@ -42,16 +42,16 @@ function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reac return Tenet.Product(tracetn) end -for A in (MPS, MPO) - @eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) - return $A(tracetn, form(prev)) - end +function Reactant.make_tracer( + seen, prev::A, path::Tuple, mode::Reactant.TraceMode; kwargs... +) where {A<:Tenet.AbstractMPO} + tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return A(tracetn, copy(form(prev))) end function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores) data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores) - return :($Tensor($data, $(inds(tocopy)))) + return :($Tensor($data, $(copy(inds(tocopy))))) end function Reactant.create_result(tocopy::TensorNetwork, @nospecialize(path), result_stores) @@ -77,26 +77,11 @@ function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), resu return :($(Tenet.Product)($tn)) end -for A in (MPS, MPO) - @eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A} - tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) - return :($A($tn, $(Tenet.form(tocopy)))) - end +function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:Tenet.AbstractMPO} + tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($A($tn, $(Tenet.form(tocopy)))) end -# TODO try rely on generic fallback for ansatzes -# function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores) -# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) -# return :($(Tenet.Product)($tn)) -# end - -# for A in (MPS, MPO) -# @eval function Reactant.create_result(tocopy::$A, @nospecialize(path), result_stores) -# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) -# return :($A($tn, form(tocopy))) -# end -# end - function Reactant.push_val!(ad_inputs, x::TensorNetwork, path) @assert length(path) == 2 @assert path[2] === :data @@ -216,7 +201,14 @@ end Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...) function Tenet.contract(a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb} - return contract(a, Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...) + # TODO change to `Ops.constant` when Ops PR lands in Reactant + # apparently `promote_to` doesn't do the transpostion for converting from column-major (Julia) to row-major layout (MLIR) + # currently, we call permutedims manually + return contract( + a, + Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, permutedims(parent(b), collect(Nb:-1:1))), inds(b)); + kwargs..., + ) end end diff --git a/ext/TenetYaoBlocksExt.jl b/ext/TenetYaoBlocksExt.jl index 4f902f9b..4033d532 100644 --- a/ext/TenetYaoBlocksExt.jl +++ b/ext/TenetYaoBlocksExt.jl @@ -25,13 +25,13 @@ function Tenet.Quantum(circuit::AbstractBlock) end # NOTE `YaoBlocks.mat` on m-site qubits still returns the operator on the full Hilbert space + m = length(occupied_locs(gate)) operator = if gate isa YaoBlocks.ControlBlock - m = length(occupied_locs(gate)) control((1:(m - 1))..., m => content(gate))(m) else content(gate) end - array = reshape(mat(operator), fill(nlevel(operator), 2 * nqubits(operator))...) + array = reshape(collect(mat(operator)), fill(nlevel(operator), 2 * nqubits(operator))...) inds = (x -> collect(Iterators.flatten(zip(x...))))( map(occupied_locs(gate)) do l diff --git a/src/Ansatz.jl b/src/Ansatz.jl index fac055e0..e4a9cf43 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -33,6 +33,8 @@ Abstract type representing the canonical form trait of a [`AbstractAnsatz`](@ref """ abstract type Form end +Base.copy(x::Form) = x + """ NonCanonical @@ -52,6 +54,8 @@ struct MixedCanonical <: Form orthog_center::Union{Site,Vector{<:Site}} end +Base.copy(x::MixedCanonical) = MixedCanonical(copy(x.orthog_center)) + """ Canonical @@ -321,21 +325,21 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, return tn end -function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, normalize=false) +function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; kwargs...) # move orthogonality center to bond mixed_canonize!(tn, bond) - return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true, normalize) + return truncate!(NonCanonical(), tn, bond; compute_local_svd=true, kwargs...) end """ - truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=true) + truncate!(::Canonical, tn::AbstractAnsatz, bond; canonize=true, kwargs...) Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest **Schmidt coefficients** or those larger than `threshold`, and then canonizes the Tensor Network if `canonize` is `true`. """ -function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=false, normalize=false) - truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false, normalize) +function truncate!(::Canonical, tn::AbstractAnsatz, bond; canonize=true, kwargs...) + truncate!(NonCanonical(), tn, bond; compute_local_svd=false, kwargs...) canonize && canonize!(tn) diff --git a/src/MPS.jl b/src/MPS.jl index f96b2f82..c5dfd8ac 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -131,9 +131,9 @@ end Check if the tensors in the mps are in the proper [`Form`](@ref). """ -check_form(mps::AbstractMPO) = check_form(form(mps), mps) +check_form(mps::AbstractMPO; kwargs...) = check_form(form(mps), mps; kwargs...) -function check_form(config::MixedCanonical, mps::AbstractMPO) +function check_form(config::MixedCanonical, mps::AbstractMPO; atol=1e-12) orthog_center = config.orthog_center left, right = if orthog_center isa Site @@ -144,23 +144,24 @@ function check_form(config::MixedCanonical, mps::AbstractMPO) for i in 1:nsites(mps) if i < left # Check left-canonical tensors - isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical")) + isisometry(mps, Site(i); dir=:right, atol) || throw(ArgumentError("Tensors are not left-canonical")) elseif i > right # Check right-canonical tensors - isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical")) + isisometry(mps, Site(i); dir=:left, atol) || throw(ArgumentError("Tensors are not right-canonical")) end end return true end -function check_form(::Canonical, mps::AbstractMPO) +function check_form(::Canonical, mps::AbstractMPO; atol=1e-12) for i in 1:nsites(mps) - if i > 1 && !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right) + if i > 1 && + !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right, atol) throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction.")) end if i < nsites(mps) && - !isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left) + !isisometry(contract(mps; between=(Site(i), Site(i + 1)), direction=:left), Site(i); dir=:left, atol) throw(ArgumentError("Can not form a right-canonical tensor in Site($i) from Γ and λ contraction.")) end end @@ -541,6 +542,133 @@ function mixed_canonize!(tn::AbstractMPO, orthog_center) return tn end +""" + evolve!(ψ::AbstractAnsatz, mpo::AbstractMPO; threshold=nothing, maxdim=nothing, normalize=true, reset_index=true) + +Evolve the [`AbstractAnsatz`](@ref) `ψ` with the [`AbstractMPO`](@ref) `mpo` along the output indices of `ψ`. +If `threshold` or `maxdim` are not `nothing`, the tensors are truncated after each sweep at the proper value, and the +bond is normalized if `normalize=true`. If `reset_index=true`, the indices of the `ψ` are reset to the original ones. +""" +function evolve!( + ψ::AbstractAnsatz, mpo::AbstractMPO; threshold=nothing, maxdim=nothing, normalize=true, reset_index=true +) + original_sites = copy(Quantum(ψ).sites) + evolve!(form(ψ), ψ, mpo; threshold, maxdim, normalize) + + if reset_index + resetindex!(ψ; init=ninds(TensorNetwork(ψ)) + 1) + + replacements = [inds(ψ; at=site) => original_sites[site] for site in keys(original_sites)] + replace!(ψ, replacements) + end + + return ψ +end + +function evolve!(::NonCanonical, ψ::AbstractAnsatz, mpo::AbstractMPO; threshold, maxdim, normalize, kwargs...) + L = nsites(ψ) + Tenet.@reindex! outputs(ψ) => inputs(mpo) + + right_inds = [inds(ψ; at=Site(i), dir=:right) for i in 1:(L - 1)] + + for i in 1:L + contract_ind = inds(ψ; at=Site(i)) + push!(ψ, tensors(mpo; at=Site(i))) + contract!(ψ, contract_ind) + merge!(Quantum(ψ).sites, Dict(Site(i) => inds(mpo; at=Site(i)))) + end + + # Group the parallel bond indices + for i in 1:(L - 1) + groupinds!(ψ, right_inds[i]) + end + + if !isnothing(threshold) || !isnothing(maxdim) + truncate_sweep!(form(ψ), ψ; threshold, maxdim, normalize) + else + normalize && normalize!(ψ) + end + + return ψ +end + +function evolve!(::MixedCanonical, ψ::AbstractAnsatz, mpo::AbstractMPO; normalize, kwargs...) + initial_form = form(ψ) + mixed_canonize!(ψ, Site(nsites(ψ))) # We convert all the tensors to left-canonical form + + evolve!(NonCanonical(), ψ, mpo; normalize, kwargs...) + + mixed_canonize!(ψ, initial_form.orthog_center) + + return ψ +end + +function evolve!(::Canonical, ψ::AbstractAnsatz, mpo::AbstractMPO; threshold, maxdim, normalize, kwargs...) + # We first join the λs to the Γs to get MixedCanonical(Site(1)) form + for i in 1:(nsites(ψ) - 1) + contract!(ψ; between=(Site(i), Site(i + 1)), direction=:right) + end + + evolve!(NonCanonical(), ψ, mpo; threshold=nothing, maxdim=nothing, normalize=false, kwargs...) # set maxdim and threshold to nothing so we truncate from Canonical form + + if !isnothing(threshold) || !isnothing(maxdim) + truncate_sweep!(Canonical(), ψ; threshold, maxdim, normalize) + else + normalize && canonize!(ψ; normalize) + end + + return ψ +end + +""" + truncate_sweep! + +Do a right-to-left QR sweep on the [`AbstractMPO`](@ref) `ψ` and then left-to-right SVD sweep and truncate the tensors +according to the `threshold` or `maxdim` values. The bond is normalized if `normalize=true`. +""" +function truncate_sweep! end + +function truncate_sweep!(::NonCanonical, ψ::AbstractMPO; threshold, maxdim, normalize) + for i in nsites(ψ):-1:2 + canonize_site!(ψ, Site(i); direction=:left, method=:qr) + end + + # left-to-right SVD sweep, get left-canonical tensors and singular values and truncate + for i in 1:(nsites(ψ) - 1) + canonize_site!(ψ, Site(i); direction=:right, method=:svd) + + (!isnothing(threshold) || !isnothing(maxdim)) && + truncate!(ψ, [Site(i), Site(i + 1)]; threshold, maxdim, normalize, compute_local_svd=false) + + contract!(ψ; between=(Site(i), Site(i + 1)), direction=:right) + end + + ψ.form = MixedCanonical(Site(nsites(ψ))) + + return ψ +end + +function truncate_sweep!(::MixedCanonical, ψ::AbstractMPO; threshold, maxdim, normalize) + truncate_sweep!(NonCanonical(), ψ; threshold, maxdim, normalize) +end + +function truncate_sweep!(::Canonical, ψ::AbstractMPO; threshold, maxdim, normalize) + for i in nsites(ψ):-1:2 + canonize_site!(ψ, Site(i); direction=:left, method=:qr) + end + + # left-to-right SVD sweep, get left-canonical tensors and singular values and truncate + for i in 1:(nsites(ψ) - 1) + canonize_site!(ψ, Site(i); direction=:right, method=:svd) + (!isnothing(threshold) || !isnothing(maxdim)) && + truncate!(ψ, [Site(i), Site(i + 1)]; threshold, maxdim, normalize, compute_local_svd=false) + end + + canonize!(ψ) + + return ψ +end + LinearAlgebra.normalize!(ψ::AbstractMPO; kwargs...) = normalize!(form(ψ), ψ; kwargs...) LinearAlgebra.normalize!(ψ::AbstractMPO, at::Site) = normalize!(form(ψ), ψ; at) LinearAlgebra.normalize!(ψ::AbstractMPO, bond::Base.AbstractVecOrTuple{Site}) = normalize!(form(ψ), ψ; bond) @@ -564,14 +692,15 @@ function LinearAlgebra.normalize!(config::MixedCanonical, ψ::AbstractMPO; at=co end function LinearAlgebra.normalize!(config::Canonical, ψ::AbstractMPO; bond=nothing) + old_norm = norm(ψ) if isnothing(bond) # Normalize all λ tensors for i in 1:(nsites(ψ) - 1) λ = tensors(ψ; between=(Site(i), Site(i + 1))) - replace!(ψ, λ => λ ./ norm(λ)^(1 / (nsites(ψ) - 1))) + replace!(ψ, λ => λ ./ old_norm^(1 / (nsites(ψ) - 1))) end else λ = tensors(ψ; between=bond) - replace!(ψ, λ => λ ./ norm(λ)) + replace!(ψ, λ => λ ./ old_norm) end return ψ diff --git a/src/Site.jl b/src/Site.jl index 4e865c93..a4c84294 100644 --- a/src/Site.jl +++ b/src/Site.jl @@ -15,6 +15,8 @@ end Site(id::Int; kwargs...) = Site((id,); kwargs...) Site(id::Vararg{Int,N}; kwargs...) where {N} = Site(id; kwargs...) +Base.copy(x::Site) = x + id(site::Site{1}) = only(site.id) id(site::Site) = site.id diff --git a/test/MPS_test.jl b/test/MPS_test.jl index 25e67fc8..b8359a4e 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -374,6 +374,61 @@ using LinearAlgebra @test_throws ArgumentError Tenet.check_form(evolved) end end + + @testset "MPO evolution" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + normalize!(ψ) + mpo = rand(MPO; n=5, maxdim=8) + + ϕ_1 = deepcopy(ψ) + ϕ_2 = deepcopy(ψ) + ϕ_3 = deepcopy(ψ) + + @testset "NonCanonical" begin + evolve!(ϕ_1, mpo) + @test length(tensors(ϕ_1)) == 5 + @test norm(ϕ_1) ≈ 1.0 + + evolved = evolve!(deepcopy(ψ), mpo; maxdim=3) + @test all(x -> x ≤ 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...)) + @test norm(evolved) ≈ 1.0 + end + + @testset "Canonical" begin + canonize!(ϕ_2) + evolve!(ϕ_2, mpo) + @test length(tensors(ϕ_2)) == 5 + 4 + @test form(ϕ_2) == Canonical() + @test Tenet.check_form(ϕ_2) + + evolved = evolve!(deepcopy(canonize!(ψ)), mpo; maxdim=3) + @test all(x -> x ≤ 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...)) + @test form(evolved) == Canonical() + @test Tenet.check_form(evolved) + end + + @testset "MixedCanonical" begin + mixed_canonize!(ϕ_3, site"3") + evolve!(ϕ_3, mpo) + @test length(tensors(ϕ_3)) == 5 + @test form(ϕ_3) == MixedCanonical(Site(3)) + @test norm(ϕ_3) ≈ 1.0 + @test Tenet.check_form(ϕ_3) + + evolved = evolve!(deepcopy(mixed_canonize!(ψ, site"3")), mpo; maxdim=3) + @test all(x -> x ≤ 3, vcat([collect(t) for t in vec(size.(tensors(evolved)))]...)) + @test form(evolved) == MixedCanonical(Site(3)) + @test norm(evolved) ≈ 1.0 + @test Tenet.check_form(evolved) + end + + t1 = contract(ϕ_1) + t2 = contract(ϕ_2) + t3 = contract(ϕ_3) + + @test t1 ≈ t2 ≈ t3 + @test only(overlap(ϕ_1, ϕ_2)) ≈ only(overlap(ϕ_1, ϕ_3)) ≈ only(overlap(ϕ_2, ϕ_3)) ≈ 1.0 + end end # TODO rename when method is renamed