Skip to content

Commit

Permalink
change tensor structure (part1)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Sep 12, 2024
1 parent 8747310 commit cdf19c4
Show file tree
Hide file tree
Showing 12 changed files with 1,053 additions and 1,094 deletions.
40 changes: 20 additions & 20 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,34 +184,34 @@ include("fusiontrees/fusiontrees.jl")
#-------------------------------------------
include("spaces/vectorspaces.jl")

# # Definitions and methods for tensors
# #-------------------------------------
# # general definitions
# Definitions and methods for tensors
#-------------------------------------
# general definitions
include("tensors/abstracttensor.jl")
include("tensors/tensortreeiterator.jl")
# include("tensors/tensortreeiterator.jl")
include("tensors/tensor.jl")
include("tensors/adjoint.jl")
include("tensors/linalg.jl")
include("tensors/vectorinterface.jl")
include("tensors/tensoroperations.jl")
include("tensors/indexmanipulations.jl")
include("tensors/truncation.jl")
include("tensors/factorizations.jl")
include("tensors/braidingtensor.jl")
# include("tensors/tensoroperations.jl")
# include("tensors/indexmanipulations.jl")
# include("tensors/truncation.jl")
# include("tensors/factorizations.jl")
# include("tensors/braidingtensor.jl")

# # Planar macros and related functionality
# #-----------------------------------------
@nospecialize
using Base.Meta: isexpr
include("planar/analyzers.jl")
include("planar/preprocessors.jl")
include("planar/postprocessors.jl")
include("planar/macros.jl")
@specialize
include("planar/planaroperations.jl")

# deprecations: to be removed in version 1.0 or sooner
include("auxiliary/deprecate.jl")
# @nospecialize
# using Base.Meta: isexpr
# include("planar/analyzers.jl")
# include("planar/preprocessors.jl")
# include("planar/postprocessors.jl")
# include("planar/macros.jl")
# @specialize
# include("planar/planaroperations.jl")

# # deprecations: to be removed in version 1.0 or sooner
# include("auxiliary/deprecate.jl")

# Extensions
# ----------
Expand Down
2 changes: 1 addition & 1 deletion src/auxiliary/auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ else
using Base: @constprop
end

const MatOrNumber{T<:Number} = Union{DenseMatrix{T},T}
const VecOrNumber{T<:Number} = Union{DenseVector{T},T}

"""
_interleave(a::NTuple{N}, b::NTuple{N}) -> NTuple{2N}
Expand Down
25 changes: 7 additions & 18 deletions src/auxiliary/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,10 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, checksquare
using ..TensorKit: OrthogonalFactorizationAlgorithm,
QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar

# only defined in >v1.7
@static if VERSION < v"1.7-"
_rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im)
_argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2]
else
_argmax(f, domain) = argmax(f, domain)
end

# TODO: define for CuMatrix if we support this
function one!(A::DenseMatrix)
Threads.@threads for j in 1:size(A, 2)
@simd for i in 1:size(A, 1)
@inbounds A[i, j] = i == j
end
end
function one!(A::StridedMatrix)
A[:] .= 0
A[diagind(A)] .= 1
return A
end

Expand Down Expand Up @@ -291,12 +280,12 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where {
while j <= n
if DI[j] == 0
vr = view(VR, :, j)
s = conj(sign(_argmax(abs, vr)))
s = conj(sign(argmax(abs, vr)))
V[:, j] .= s .* vr
else
vr = view(VR, :, j)
vi = view(VR, :, j + 1)
s = conj(sign(_argmax(abs, vr))) # vectors coming from lapack have already real absmax component
s = conj(sign(argmax(abs, vr))) # vectors coming from lapack have already real absmax component
V[:, j] .= s .* (vr .+ im .* vi)
V[:, j + 1] .= s .* (vr .- im .* vi)
j += 1
Expand All @@ -314,7 +303,7 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true,
A)[[2, 4]]
for j in 1:n
v = view(V, :, j)
s = conj(sign(_argmax(abs, v)))
s = conj(sign(argmax(abs, v)))
v .*= s
end
return D, V
Expand All @@ -326,7 +315,7 @@ function eigh!(A::StridedMatrix{T}) where {T<:BlasFloat}
D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0)
for j in 1:n
v = view(V, :, j)
s = conj(sign(_argmax(abs, v)))
s = conj(sign(argmax(abs, v)))
v .*= s
end
return D, V
Expand Down
108 changes: 106 additions & 2 deletions src/spaces/homspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ function blocksectors(W::HomSpace)
if N₁ == 0 || N₂ == 0
return (one(I),)
elseif N₂ <= N₁
return filter!(c -> hasblock(codom, c), collect(blocksectors(dom)))
return sort!(filter!(c -> hasblock(codom, c), collect(blocksectors(dom))))
else
return filter!(c -> hasblock(dom, c), collect(blocksectors(codom)))
return sort!(filter!(c -> hasblock(dom, c), collect(blocksectors(codom))))
end
end

Expand Down Expand Up @@ -134,3 +134,107 @@ function compose(W::HomSpace{S}, V::HomSpace{S}) where {S}
domain(W) == codomain(V) || throw(SpaceMismatch("$(domain(W))$(codomain(V))"))
return HomSpace(codomain(W), domain(V))
end

# Block and fusion tree ranges: structure information for building tensors
#--------------------------------------------------------------------------
struct TensorStructure{I,F₁,F₂}
totaldim::Int
blockstructure::SectorDict{I,Tuple{Tuple{Int,Int},UnitRange{Int}}}
fusiontreelist::Vector{Tuple{F₁,F₂}}
fusiontreeranges::Vector{Tuple{UnitRange{Int},UnitRange{Int}}}
fusiontreeindices::FusionTreeDict{Tuple{F₁,F₂},Int}
end

abstract type CacheStyle end
struct NoCache <: CacheStyle end
struct TaskLocalCache{D<:AbstractDict} <: CacheStyle end
struct GlobalCache <: CacheStyle end

function CacheStyle(I::Type{<:Sector})
return GlobalCache()
# if FusionStyle(I) === UniqueFusion()
# return TaskLocalCache{SectorDict{I,Any}}()
# else
# return GlobalCache()
# end
end

tensorstructure(W::HomSpace) = tensorstructure(W, CacheStyle(sectortype(W)))

function tensorstructure(W::HomSpace, ::NoCache)
codom = codomain(W)
dom = domain(W)
N₁ = length(codom)
N₂ = length(dom)
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)

blockstructure = SectorDict{I,Tuple{Tuple{Int,Int},UnitRange{Int}}}()
fusiontreelist = Vector{Tuple{F₁,F₂}}()
fusiontreeranges = Vector{Tuple{UnitRange{Int},UnitRange{Int}}}()
outer_offset = 0
for c in blocksectors(W)
inner_offset₂ = 0
inner_offset₁ = 0
for f₂ in fusiontrees(dom, c)
s₂ = f₂.uncoupled
d₂ = dim(dom, s₂)
r₂ = (inner_offset₂ + 1):(inner_offset₂ + d₂)
inner_offset₂ = last(r₂)
# TODO: # now we run the code below for every f₂; should we do this separately
inner_offset₁ = 0 # reset here to avoid multiple counting
for f₁ in fusiontrees(codom, c)
s₁ = f₁.uncoupled
d₁ = dim(codom, s₁)
r₁ = (inner_offset₁ + 1):(inner_offset₁ + d₁)
inner_offset₁ = last(r₁)
push!(fusiontreelist, (f₁, f₂))
push!(fusiontreeranges, (r₁, r₂))
end
end
blocksize = (inner_offset₁, inner_offset₂)
blocklength = blocksize[1] * blocksize[2]
blockrange = (outer_offset + 1):(outer_offset + blocklength)
outer_offset = last(blockrange)
blockstructure[c] = (blocksize, blockrange)
end
fusiontreeindices = sizehint!(FusionTreeDict{Tuple{F₁,F₂},Int}(), length(fusiontreelist))
for i = 1:length(fusiontreelist)
fusiontreeindices[fusiontreelist[i]] = i
end
totaldim = outer_offset
structure = TensorStructure(totaldim, blockstructure, fusiontreelist, fusiontreeranges, fusiontreeindices)
return structure
end

function tensorstructure(W::HomSpace, ::TaskLocalCache{D}) where {D}
cache::D = get!(task_local_storage(), :_local_tensorstructure_cache) do
return D()
end
N₁ = length(codomain(W))
N₂ = length(domain(W))
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
structure::TensorStructure{I,F₁,F₂} = get!(cache, W) do
tensorstructure(W, NoCache())
end
return structure
end

const GLOBAL_TENSORSTRUCTURE_CACHE = LRU{Any,Any}(; maxsize=10^4)
# 10^4 different tensor spaces should be enough for most purposes
function tensorstructure(W::HomSpace, ::GlobalCache)
cache = GLOBAL_TENSORSTRUCTURE_CACHE
N₁ = length(codomain(W))
N₂ = length(domain(W))
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
structure::TensorStructure{I,F₁,F₂} = get!(cache, W) do
return tensorstructure(W, NoCache())
end
return structure
end

28 changes: 14 additions & 14 deletions src/spaces/productspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,6 @@ function Base.axes(P::ProductSpace{<:ElementarySpace,N},
return map(axes, P.spaces, sectors)
end

"""
fusiontrees(P::ProductSpace, blocksector::Sector)
Return an iterator over all fusion trees that can be formed by fusing the sectors present
in the different spaces that make up the `ProductSpace` instance, resulting in the coupled
sector `blocksector`.
"""
function fusiontrees(P::ProductSpace{S,N}, blocksector::I) where {S,N,I}
I == sectortype(S) || throw(SectorMismatch())
uncoupled = map(sectors, P.spaces)
isdualflags = map(isdual, P.spaces)
return FusionTreeIterator(uncoupled, blocksector, isdualflags)
end

"""
blocksectors(P::ProductSpace)
Expand Down Expand Up @@ -179,6 +165,20 @@ function blocksectors(P::ProductSpace{S,N}) where {S,N}
return bs
end

"""
fusiontrees(P::ProductSpace, blocksector::Sector)
Return an iterator over all fusion trees that can be formed by fusing the sectors present
in the different spaces that make up the `ProductSpace` instance into the coupled sector
`blocksector`.
"""
function fusiontrees(P::ProductSpace{S,N}, blocksector::I) where {S,N,I}
I == sectortype(S) || throw(SectorMismatch())
uncoupled = map(sectors, P.spaces)
isdualflags = map(isdual, P.spaces)
return FusionTreeIterator(uncoupled, blocksector, isdualflags)
end

"""
hasblock(P::ProductSpace, c::Sector)
Expand Down
Loading

0 comments on commit cdf19c4

Please sign in to comment.