diff --git a/src/Ansatz.jl b/src/Ansatz.jl index b496b111..5a6bc728 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -49,7 +49,7 @@ struct NonCanonical <: Form end left of the orthogonality center are left-canonical and the tensors to the right are right-canonical. """ struct MixedCanonical <: Form - orthog_center::Union{Site,Vector{Site}} + orthog_center::Union{Site,Vector{<:Site}} end """ @@ -255,8 +255,8 @@ Truncate the dimension of the virtual `bond`` of an [`Ansatz`](@ref) Tensor Netw - Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used. """ -function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing) - return truncate!(form(tn), tn, bond; threshold, maxdim) +function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing, kwargs...) + return truncate!(form(tn), tn, bond; threshold, maxdim, kwargs...) end """ @@ -312,11 +312,21 @@ end function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim) # move orthogonality center to bond mixed_canonize!(tn, bond) - return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false) + return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true) end -function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim) - return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false) +""" + truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true) + +Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest +**Schmidt coefficients** or those larger than `threshold`, and then recanonizes the Tensor Network if `recanonize` is `true`. +""" +function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true) + truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false) + + recanonize && canonize!(tn) + + return tn end overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)')) diff --git a/src/MPS.jl b/src/MPS.jl index 0a6e69c3..7efef4a4 100644 --- a/src/MPS.jl +++ b/src/MPS.jl @@ -126,14 +126,26 @@ function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check=true) return mps end +""" + check_form(mps::AbstractMPO) + +Check if the tensors in the mps are in the proper [`Form`](@ref). +""" check_form(mps::AbstractMPO) = check_form(form(mps), mps) function check_form(config::MixedCanonical, mps::AbstractMPO) orthog_center = config.orthog_center + + left, right = if orthog_center isa Site + id(orthog_center) .+ (0, 0) # So left and right get the same value + elseif orthog_center isa Vector{<:Site} + extrema(id.(orthog_center)) + end + for i in 1:nsites(mps) - if i < id(orthog_center) # Check left-canonical tensors + if i < left # Check left-canonical tensors isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical")) - elseif i > id(orthog_center) # Check right-canonical tensors + elseif i > right # Check right-canonical tensors isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical")) end end @@ -156,6 +168,8 @@ function check_form(::Canonical, mps::AbstractMPO) return true end +check_form(::NonCanonical, mps::AbstractMPO) = true + """ MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) @@ -504,19 +518,24 @@ end # TODO dispatch on form # TODO generalize to AbstractAnsatz function mixed_canonize!(tn::AbstractMPO, orthog_center) + left, right = if orthog_center isa Site + id(orthog_center) .+ (-1, 1) + elseif orthog_center isa Vector{<:Site} + extrema(id.(orthog_center)) .+ (-1, 1) + else + throw(ArgumentError("`orthog_center` must be a `Site` or a `Vector{Site}`")) + end + # left-to-right QR sweep (left-canonical tensors) - for i in 1:(id(orthog_center) - 1) + for i in 1:left canonize_site!(tn, Site(i); direction=:right, method=:qr) end # right-to-left QR sweep (right-canonical tensors) - for i in nsites(tn):-1:(id(orthog_center) + 1) + for i in nsites(tn):-1:right canonize_site!(tn, Site(i); direction=:left, method=:qr) end - # center SVD sweep to get singular values - # canonize_site!(tn, orthog_center; direction=:left, method=:svd) - tn.form = MixedCanonical(orthog_center) return tn diff --git a/test/MPS_test.jl b/test/MPS_test.jl index 18027ac8..1dd48015 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -95,19 +95,36 @@ using LinearAlgebra end @testset "truncate!" begin - ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - canonize_site!(ψ, Site(2); direction=:right, method=:svd) + @testset "NonCanonical" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + canonize_site!(ψ, Site(2); direction=:right, method=:svd) + + truncated = truncate(ψ, [site"2", site"3"]; maxdim=1) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + + singular_values = tensors(ψ; between=(site"2", site"3")) + truncated = truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + + # If maxdim > size(spectrum), the bond dimension is not truncated + truncated = truncate(ψ, [site"2", site"3"]; maxdim=4) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2 + end + + @testset "Canonical" begin + ψ = rand(MPS; n=5, maxdim=16) + canonize!(ψ) - truncated = truncate(ψ, [site"2", site"3"]; maxdim=1) - @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + truncated = truncate(ψ, [site"2", site"3"]; maxdim=2) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2 + end - singular_values = tensors(ψ; between=(site"2", site"3")) - truncated = truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) - @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + @testset "MixedCanonical" begin + ψ = rand(MPS; n=5, maxdim=16) - # If maxdim > size(spectrum), the bond dimension is not truncated - truncated = truncate(ψ, [site"2", site"3"]; maxdim=4) - @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2 + truncated = truncate(ψ, [site"2", site"3"]; maxdim=3) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 3 + end end @testset "norm" begin @@ -210,18 +227,36 @@ using LinearAlgebra end @testset "mixed_canonize!" begin - ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - canonized = mixed_canonize(ψ, site"3") + @testset "single Site" begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = mixed_canonize(ψ, site"3") + @test Tenet.check_form(canonized) + + @test form(canonized) isa MixedCanonical + @test form(canonized).orthog_center == site"3" - @test form(canonized) isa MixedCanonical - @test form(canonized).orthog_center == site"3" + @test isisometry(canonized, site"1"; dir=:right) + @test isisometry(canonized, site"2"; dir=:right) + @test isisometry(canonized, site"4"; dir=:left) + @test isisometry(canonized, site"5"; dir=:left) + + @test contract(canonized) ≈ contract(ψ) + end - @test isisometry(canonized, site"1"; dir=:right) - @test isisometry(canonized, site"2"; dir=:right) - @test isisometry(canonized, site"4"; dir=:left) - @test isisometry(canonized, site"5"; dir=:left) + @testset "multiple Sites" begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = mixed_canonize(ψ, [site"2", site"3"]) - @test contract(canonized) ≈ contract(ψ) + @test Tenet.check_form(canonized) + @test form(canonized) isa MixedCanonical + @test form(canonized).orthog_center == [site"2", site"3"] + + @test isisometry(canonized, site"1"; dir=:right) + @test isisometry(canonized, site"4"; dir=:left) + @test isisometry(canonized, site"5"; dir=:left) + + @test contract(canonized) ≈ contract(ψ) + end end @testset "expect" begin