Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance simple_update! for MPS in the Canonical form #255

Merged
merged 26 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
24674da
First round of fixes on simple_update
jofrevalles Nov 19, 2024
5f158d2
Fix tests
jofrevalles Nov 20, 2024
99e6533
Add simple_update_2site! for MixedCanonical form
jofrevalles Nov 20, 2024
07a2af9
Format code
jofrevalles Nov 20, 2024
4c28ab9
Renormalize mps in truncation when recanonize kwarg is true
jofrevalles Nov 20, 2024
7612188
Enhance tests
jofrevalles Nov 20, 2024
967d711
Change default recanonize kwarg to false in truncate! function
jofrevalles Nov 20, 2024
cfd2a0f
Refactor normalize functions
jofrevalles Nov 21, 2024
4e6835d
Enhance normalize tests
jofrevalles Nov 21, 2024
b4d678e
Define LinearAlgebra.normalize for AbstractQuantum
jofrevalles Nov 21, 2024
64c77ef
Fix normalize functions for MPS
jofrevalles Nov 21, 2024
52b30c4
Enhance tests
jofrevalles Nov 21, 2024
4db9ccb
Fix normalize for Canonical MPS
jofrevalles Nov 21, 2024
bde48f7
Format code
jofrevalles Nov 21, 2024
c9cf7d7
Update normalization step on evolve
jofrevalles Nov 21, 2024
7232631
Change normalization to all lambdas for Canonical form
jofrevalles Nov 21, 2024
58f5e32
Format code
jofrevalles Nov 21, 2024
9fc52bd
Fix truncate by adding renormalize kwarg
jofrevalles Nov 21, 2024
6b76a71
Small enhancements on normalize! functions
jofrevalles Nov 21, 2024
2948ffd
Enhance tests
jofrevalles Nov 21, 2024
3ffd4e0
Change default kwargs in truncate
jofrevalles Nov 21, 2024
f844ac0
Fix evolve kwargs
jofrevalles Nov 21, 2024
a0b0267
Fix normalize! by putting replace! instead of inplace modification fo…
jofrevalles Nov 21, 2024
5f97d6b
Enhance tests
jofrevalles Nov 21, 2024
03e886b
Fix aesthetic suggestions, improve kwarg definition
jofrevalles Nov 22, 2024
2c46309
Update comment
jofrevalles Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 78 additions & 26 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ mixed_canonize(tn::AbstractAnsatz, args...; kwargs...) = mixed_canonize!(deepcop

canonize_site(tn::AbstractAnsatz, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...)

"""
normalize!(ψ::AbstractAnsatz, at)

Normalize the state at a given [`Site`](@ref) or bond in a [`AbstractAnsatz`](@ref) Tensor Network.
"""
LinearAlgebra.normalize(ψ::AbstractAnsatz, site) = normalize!(copy(ψ), site)

"""
isisometry(tn::AbstractAnsatz, site; dir, kwargs...)

Expand Down Expand Up @@ -274,8 +281,9 @@ Truncate the dimension of the virtual `bond` of a [`NonCanonical`](@ref) Tensor
- `threshold`: The threshold to truncate the bond dimension.
- `maxdim`: The maximum bond dimension to keep.
- `compute_local_svd`: Whether to compute the local SVD of the bond. If `true`, it will contract the bond and perform a SVD to get the local singular values. Defaults to `true`.
- `normalize`: Whether to normalize the state at the bond after truncation. Defaults to `false`.
"""
function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, compute_local_svd=true)
function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, compute_local_svd=true, normalize=false)
virtualind = inds(tn; bond)

if compute_local_svd
Expand Down Expand Up @@ -305,26 +313,31 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim,
end

slice!(tn, virtualind, extent)
sliced_bond = tensors(tn; bond)

# Note: Inplace normalization of the inner arrays may be more efficient
normalize && replace!(tn, sliced_bond => sliced_bond ./ norm(tn))

return tn
end

function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim)
function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, normalize=false)
# move orthogonality center to bond
mixed_canonize!(tn, bond)
return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true)

return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true, normalize)
end

"""
truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true)
truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=true)

Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest
**Schmidt coefficients** or those larger than `threshold`, and then recanonizes the Tensor Network if `recanonize` is `true`.
**Schmidt coefficients** or those larger than `threshold`, and then canonizes the Tensor Network if `canonize` is `true`.
"""
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true)
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false)
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=false, normalize=false)
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false, normalize)

recanonize && canonize!(tn)
canonize && canonize!(tn)

return tn
end
Expand Down Expand Up @@ -354,7 +367,7 @@ function expect(ψ::AbstractAnsatz, observables::AbstractVecOrTuple; bra=copy(ψ
end

"""
evolve!(ψ::AbstractAnsatz, gate; threshold = nothing, maxdim = nothing, renormalize = false)
evolve!(ψ::AbstractAnsatz, gate; threshold = nothing, maxdim = nothing, normalize = false)

Evolve (through time) a [`AbstractAnsatz`](@ref) Tensor Network with a `gate` operator.

Expand All @@ -367,16 +380,16 @@ Evolve (through time) a [`AbstractAnsatz`](@ref) Tensor Network with a `gate` op

- `threshold`: The threshold to truncate the bond dimension.
- `maxdim`: The maximum bond dimension to keep.
- `renormalize`: Whether to renormalize the state after truncation.
- `normalize`: Whether to normalize the state after truncation.

# Notes

- The gate must act on neighboring sites according to the [`Lattice`](@ref) of the Tensor Network.
- The gate must have the same number of inputs and outputs.
- Currently only the "Simple Update" algorithm is used and the gate must be a 1-site or 2-site operator.
"""
function evolve!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false)
return simple_update!(ψ, gate; threshold, maxdim, renormalize)
function evolve!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, normalize=false, kwargs...)
return simple_update!(ψ, gate; threshold, maxdim, normalize, kwargs...)
end

# by popular demand (Stefano, I'm looking at you), I aliased `apply!` to `evolve!`
Expand All @@ -387,11 +400,11 @@ function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=noth

if nlanes(gate) == 1
return simple_update_1site!(ψ, gate)
elseif nlanes(gate) == 2
return simple_update_2site!(form(ψ), ψ, gate; threshold, maxdim, kwargs...)
else
throw(ArgumentError("Only 1-site and 2-site gates are currently supported"))
end

@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

return simple_update!(form(ψ), ψ, gate; threshold, maxdim, kwargs...)
end

# TODO a lot of problems with merging... maybe we shouldn't merge manually
Expand Down Expand Up @@ -419,9 +432,15 @@ function simple_update_1site!(ψ::AbstractAnsatz, gate)
return contract!(ψ, contracting_index)
end

# TODO remove `renormalize` argument?
function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false)
@assert nlanes(gate) == 2 "Only 2-site gates are supported currently"
function simple_update_2site!(
::MixedCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, normalize=false
)
return simple_update_2site!(NonCanonical(), ψ, gate; threshold, maxdim, normalize)
end

function simple_update_2site!(
::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, normalize=false
)
@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

# shallow copy to avoid problems if errors in mid execution
Expand Down Expand Up @@ -455,16 +474,49 @@ function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=noth

# truncate virtual index
if any(!isnothing, (threshold, maxdim))
truncate!(ψ, bond; threshold, maxdim)
renormalize && normalize!(ψ, bond[1])
truncate!(ψ, collect(bond); threshold, maxdim, normalize)
end

return ψ
end

# TODO remove `renormalize` argument?
# TODO optimize correctly -> avoid recanonization + use lateral Λs
function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false)
simple_update!(NonCanonical(), ψ, gate; threshold, maxdim, renormalize)
return canonize!(ψ)
# TODO remove `normalize` argument?
function simple_update_2site!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, normalize=false, canonize=true)
# Contract the exterior Λ tensors
sitel, siter = extrema(lanes(gate))
(0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) ||
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel))
Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1)))

!isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false)
!isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false)

simple_update_2site!(NonCanonical(), ψ, gate; threshold, maxdim, normalize=false)

# contract the updated tensors with the inverse of Λᵢ and Λᵢ₊₂, to get the new Γ tensors
U, Vt = tensors(ψ; at=sitel), tensors(ψ; at=siter)
Γᵢ₋₁ = if isnothing(Λᵢ₋₁)
U
else
contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=wrap_eps(eltype(U)))), inds(Λᵢ₋₁)); dims=())
end
Γᵢ = if isnothing(Λᵢ₊₁)
Vt
else
contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=wrap_eps(eltype(Vt)))), inds(Λᵢ₊₁)), Vt; dims=())
end

# Update the tensors in the tensor network
replace!(ψ, tensors(ψ; at=sitel) => Γᵢ₋₁)
replace!(ψ, tensors(ψ; at=siter) => Γᵢ)

if canonize
canonize!(ψ; normalize)
else
normalize && normalize!(ψ, collect((sitel, siter)))
end

return ψ
end
31 changes: 25 additions & 6 deletions src/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ function canonize_site!(ψ::MPS, site::Site; direction::Symbol, method=:qr)
return ψ
end

function canonize!(ψ::AbstractMPO)
function canonize!(ψ::AbstractMPO; normalize=false)
Λ = Tensor[]

# right-to-left QR sweep, get right-canonical tensors
Expand All @@ -495,6 +495,7 @@ function canonize!(ψ::AbstractMPO)

# extract the singular values and contract them with the next tensor
Λᵢ = pop!(ψ, tensors(ψ; between=(Site(i), Site(i + 1))))
normalize && (Λᵢ ./= norm(Λᵢ))
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
Aᵢ₊₁ = tensors(ψ; at=Site(i + 1))
replace!(ψ, Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ; dims=()))
push!(Λ, Λᵢ)
Expand Down Expand Up @@ -541,19 +542,37 @@ function mixed_canonize!(tn::AbstractMPO, orthog_center)
end

LinearAlgebra.normalize!(ψ::AbstractMPO; kwargs...) = normalize!(form(ψ), ψ; kwargs...)
LinearAlgebra.normalize!(ψ::AbstractMPO, at::Site) = normalize!(form(ψ), ψ; at)
LinearAlgebra.normalize!(ψ::AbstractMPO, bond::Base.AbstractVecOrTuple{Site}) = normalize!(form(ψ), ψ; bond)

# NOTE: Inplace normalization of the arrays should be faster, but currently lead to problems for `copy` TensorNetworks
function LinearAlgebra.normalize!(::NonCanonical, ψ::AbstractMPO; at=Site(nsites(ψ) ÷ 2))
tensor = tensors(ψ; at)
tensor ./= norm(ψ)
if at isa Site
tensor = tensors(ψ; at)
replace!(ψ, tensor => tensor ./ norm(ψ))
else
normalize!(mixed_canonize!(ψ, at))
end

return ψ
end

LinearAlgebra.normalize!(ψ::AbstractMPO, site::Site) = normalize!(mixed_canonize!(ψ, site); at=site)

function LinearAlgebra.normalize!(config::MixedCanonical, ψ::AbstractMPO; at=config.orthog_center)
mixed_canonize!(ψ, at)
normalize!(tensors(ψ; at), 2)
return ψ
end

# TODO function LinearAlgebra.normalize!(::Canonical, ψ::AbstractMPO) end
function LinearAlgebra.normalize!(config::Canonical, ψ::AbstractMPO; bond=nothing)
if isnothing(bond) # Normalize all λ tensors
for i in 1:(nsites(ψ) - 1)
λ = tensors(ψ; between=(Site(i), Site(i + 1)))
replace!(ψ, λ => λ ./ norm(λ)^(1 / (nsites(ψ) - 1)))
end
else
λ = tensors(ψ; between=bond)
replace!(ψ, λ => λ ./ norm(λ))
end

return ψ
end
2 changes: 2 additions & 0 deletions src/Quantum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,8 @@ function Base.merge!(a::AbstractQuantum, b::AbstractQuantum; reset=true)
return a
end

LinearAlgebra.normalize(ψ::AbstractQuantum; kwargs...) = normalize!(copy(ψ); kwargs...)

function LinearAlgebra.norm(ψ::AbstractQuantum, p::Real=2; kwargs...)
p == 2 || throw(ArgumentError("only L2-norm is implemented yet"))
return LinearAlgebra.norm2(ψ; kwargs...)
Expand Down
91 changes: 76 additions & 15 deletions test/MPS_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,21 +109,33 @@ using LinearAlgebra
# If maxdim > size(spectrum), the bond dimension is not truncated
truncated = truncate(ψ, [site"2", site"3"]; maxdim=4)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2
end

@testset "Canonical" begin
ψ = rand(MPS; n=5, maxdim=16)
canonize!(ψ)

truncated = truncate(ψ, [site"2", site"3"]; maxdim=2)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2
normalize!(ψ)
truncated = truncate(ψ, [site"2", site"3"]; maxdim=1, normalize=true)
@test norm(truncated) ≈ 1.0
end

@testset "MixedCanonical" begin
ψ = rand(MPS; n=5, maxdim=16)

truncated = truncate(ψ, [site"2", site"3"]; maxdim=3)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 3

truncated = truncate(ψ, [site"2", site"3"]; maxdim=3, normalize=true)
@test norm(truncated) ≈ 1.0
end

@testset "Canonical" begin
ψ = rand(MPS; n=5, maxdim=16)
canonize!(ψ)

truncated = truncate(ψ, [site"2", site"3"]; maxdim=2, canonize=true, normalize=true)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2
@test Tenet.check_form(truncated)
@test norm(truncated) ≈ 1.0

truncated = truncate(ψ, [site"2", site"3"]; maxdim=2, canonize=false, normalize=true)
@test norm(truncated) ≈ 1.0
end
end

Expand All @@ -144,11 +156,42 @@ using LinearAlgebra
end

@testset "normalize!" begin
using LinearAlgebra: normalize!
using LinearAlgebra: normalize, normalize!

ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
normalize!(ψ, Site(3))
@test isapprox(norm(ψ), 1.0)
@testset "NonCanonical" begin
ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])

normalized = normalize(ψ)
@test norm(normalized) ≈ 1.0

normalize!(ψ, Site(3))
@test norm(ψ) ≈ 1.0
end

@testset "MixedCanonical" begin
ψ = rand(MPS; n=5, maxdim=16)

# Perturb the state to make it non-normalized
t = tensors(ψ; at=site"3")
replace!(ψ, t => Tensor(rand(size(t)...), inds(t)))

normalized = normalize(ψ)
@test norm(normalized) ≈ 1.0

normalize!(ψ, Site(3))
@test norm(ψ) ≈ 1.0
end

@testset "Canonical" begin
ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
canonize!(ψ)

normalized = normalize(ψ)
@test norm(normalized) ≈ 1.0

normalize!(ψ, (Site(3), Site(4)))
@test norm(ψ) ≈ 1.0
end
end

@testset "canonize_site!" begin
Expand Down Expand Up @@ -303,14 +346,32 @@ using LinearAlgebra
@test length(tensors(ϕ)) == 5
@test issetequal(size.(tensors(ϕ)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)])
@test isapprox(contract(ϕ), contract(ψ))

evolved = evolve!(normalize(ψ), gate; maxdim=1, normalize=true)
@test norm(evolved) ≈ 1.0
end

@testset "Canonical" begin
ψ = deepcopy(ψ)
ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)])
normalize!(ψ)
ϕ = deepcopy(ψ)

canonize!(ψ)
evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14)
@test isapprox(contract(evolved), contract(ψ))
@test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)])

evolved = evolve!(deepcopy(ψ), gate)
@test Tenet.check_form(evolved)
@test isapprox(contract(evolved), contract(ϕ)) # Identity gate should not change the state

# Ensure that the original MixedCanonical state evolves into the same state as the canonicalized one
@test contract(ψ) ≈ contract(evolve!(ϕ, gate; threshold=1e-14))

evolved = evolve!(deepcopy(ψ), gate; maxdim=1, normalize=true, canonize=true)
@test norm(evolved) ≈ 1.0
@test Tenet.check_form(evolved)

evolved = evolve!(deepcopy(ψ), gate; maxdim=1, normalize=true, canonize=false)
@test norm(evolved) ≈ 1.0
@test_throws ArgumentError Tenet.check_form(evolved)
end
end
end
Expand Down
Loading