Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sectorscalartype #146

Merged
merged 4 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions TensorKitSectors/src/TensorKitSectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export Sector, Group, AbstractIrrep
export Irrep

export Nsymbol, Fsymbol, Rsymbol, Asymbol, Bsymbol
export sectorscalartype
export dim, sqrtdim, invsqrtdim, frobeniusschur, twist, fusiontensor, dual
export otimes, deligneproduct, times
export FusionStyle, UniqueFusion, MultipleFusion, SimpleFusion, GenericFusion,
Expand Down
52 changes: 30 additions & 22 deletions TensorKitSectors/src/irreps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,22 +197,26 @@ findindex(::SectorValues{SU2Irrep}, s::SU2Irrep) = twice(s.j) + 1
dim(s::SU2Irrep) = twice(s.j) + 1

FusionStyle(::Type{SU2Irrep}) = SimpleFusion()
sectorscalartype(::Type{SU2Irrep}) = Float64
Base.isreal(::Type{SU2Irrep}) = true

Nsymbol(sa::SU2Irrep, sb::SU2Irrep, sc::SU2Irrep) = WignerSymbols.δ(sa.j, sb.j, sc.j)

function Fsymbol(s1::SU2Irrep, s2::SU2Irrep, s3::SU2Irrep,
s4::SU2Irrep, s5::SU2Irrep, s6::SU2Irrep)
if all(==(_su2one), (s1, s2, s3, s4, s5, s6))
return 1.0
else
return sqrtdim(s5) * sqrtdim(s6) *
WignerSymbols.racahW(Float64, s1.j, s2.j,
WignerSymbols.racahW(sectorscalartype(SU2Irrep), s1.j, s2.j,
s4.j, s3.j, s5.j, s6.j)
end
end

function Rsymbol(sa::SU2Irrep, sb::SU2Irrep, sc::SU2Irrep)
Nsymbol(sa, sb, sc) || return 0.0
return iseven(convert(Int, sa.j + sb.j - sc.j)) ? 1.0 : -1.0
Nsymbol(sa, sb, sc) || return zero(sectorscalartype(SU2Irrep))
return iseven(convert(Int, sa.j + sb.j - sc.j)) ? one(sectorscalartype(SU2Irrep)) :
-one(sectorscalartype(SU2Irrep))
end

function fusiontensor(a::SU2Irrep, b::SU2Irrep, c::SU2Irrep)
Expand Down Expand Up @@ -341,59 +345,62 @@ Base.eltype(::Type{CU1ProdIterator}) = CU1Irrep
dim(c::CU1Irrep) = ifelse(c.j == zero(HalfInt), 1, 2)

FusionStyle(::Type{CU1Irrep}) = SimpleFusion()
sectorscalartype(::Type{CU1Irrep}) = Float64
Base.isreal(::Type{CU1Irrep}) = true

function Nsymbol(a::CU1Irrep, b::CU1Irrep, c::CU1Irrep)
return ifelse(c.s == 0, (a.j == b.j) & ((a.s == b.s == 2) | (a.s == b.s)),
ifelse(c.s == 1, (a.j == b.j) & ((a.s == b.s == 2) | (a.s != b.s)),
(c.j == a.j + b.j) | (c.j == abs(a.j - b.j))))
end

function Fsymbol(a::CU1Irrep, b::CU1Irrep, c::CU1Irrep,
d::CU1Irrep, e::CU1Irrep, f::CU1Irrep)
Nabe = convert(Int, Nsymbol(a, b, e))
Necd = convert(Int, Nsymbol(e, c, d))
Nbcf = convert(Int, Nsymbol(b, c, f))
Nafd = convert(Int, Nsymbol(a, f, d))

Nabe * Necd * Nbcf * Nafd == 0 && return 0.0
T = sectorscalartype(CU1Irrep)
Nabe * Necd * Nbcf * Nafd == 0 && return zero(T)

op = CU1Irrep(0, 0)
om = CU1Irrep(0, 1)

if a == op || b == op || c == op
return 1.0
return one(T)
end
if (a == b == om) || (a == c == om) || (b == c == om)
return 1.0
return one(T)
end
if a == om
if d.j == zero(HalfInt)
return 1.0
return one(T)
else
return (d.j == c.j - b.j) ? -1.0 : 1.0
return (d.j == c.j - b.j) ? -one(T) : one(T)
end
end
if b == om
return (d.j == abs(a.j - c.j)) ? -1.0 : 1.0
return (d.j == abs(a.j - c.j)) ? -one(T) : one(T)
end
if c == om
return (d.j == a.j - b.j) ? -1.0 : 1.0
return (d.j == a.j - b.j) ? -one(T) : one(T)
end
# from here on, a, b, c are neither 0+ or 0-
s = sqrt(2) / 2
s = T(sqrt(2) / 2)
if a == b == c
if d == a
if e.j == 0
if f.j == 0
return f.s == 1 ? -0.5 : 0.5
return f.s == 1 ? T(-1 // 2) : T(1 // 2)
else
return e.s == 1 ? -s : s
end
else
return f.j == 0 ? s : 0.0
return f.j == 0 ? s : zero(T)
end
else
return 1.0
return one(T)
end
end
if a == b # != c
Expand All @@ -404,7 +411,7 @@ function Fsymbol(a::CU1Irrep, b::CU1Irrep, c::CU1Irrep,
return s
end
else
return 1.0
return one(T)
end
end
if b == c
Expand All @@ -415,27 +422,28 @@ function Fsymbol(a::CU1Irrep, b::CU1Irrep, c::CU1Irrep,
return f.s == 1 ? -s : s
end
else
return 1.0
return one(T)
end
end
if a == c
if d == b
if e.j == f.j
return 0.0
return zero(T)
else
return 1.0
return one(T)
end
else
return d.s == 1 ? -1.0 : 1.0
return d.s == 1 ? -one(T) : one(T)
end
end
if d == om
return b.j == a.j + c.j ? -1.0 : 1.0
return b.j == a.j + c.j ? -one(T) : one(T)
end
return 1.0
return one(T)
end

function Rsymbol(a::CU1Irrep, b::CU1Irrep, c::CU1Irrep)
R = convert(Float64, Nsymbol(a, b, c))
R = convert(sectorscalartype(CU1Irrep), Nsymbol(a, b, c))
return c.s == 1 && a.j > 0 ? -R : R
end

Expand Down
25 changes: 16 additions & 9 deletions TensorKitSectors/src/sectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,28 @@ Return the conjugate label `conj(a)`.
dual(a::Sector) = conj(a)

"""
isreal(::Type{<:Sector}) -> Bool
sectorscalartype(I::Type{<:Sector}) -> Type

Return whether the topological data (Fsymbol, Rsymbol) of the sector is real or not (in
which case it is complex).
Return the scalar type of the topological data (Fsymbol, Rsymbol) of the sector `I`.
"""
function Base.isreal(I::Type{<:Sector})
u = one(I)
if BraidingStyle(I) isa HasBraiding
return (eltype(Fsymbol(u, u, u, u, u, u)) <: Real) &&
(eltype(Rsymbol(u, u, u)) <: Real)
function sectorscalartype(::Type{I}) where {I<:Sector}
if BraidingStyle(I) isa NoBraiding
return eltype(Core.Compiler.return_type(Fsymbol, NTuple{6,I}))
else
return (eltype(Fsymbol(u, u, u, u, u, u)) <: Real)
Feltype = eltype(Core.Compiler.return_type(Fsymbol, NTuple{6,I}))
Reltype = eltype(Core.Compiler.return_type(Rsymbol, NTuple{3,I}))
return Base.promote_op(*, Feltype, Reltype)
end
end

"""
isreal(::Type{<:Sector}) -> Bool

Return whether the topological data (Fsymbol, Rsymbol) of the sector is real or not (in
which case it is complex).
"""
Base.isreal(I::Type{<:Sector}) = sectorscalartype(I) <: Real

# FusionStyle: the most important aspect of Sector
#---------------------------------------------
"""
Expand Down
49 changes: 16 additions & 33 deletions src/fusiontrees/manipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# this actually removes uncoupled line i, which should be trivial
(f₁.uncoupled[i] == f₂.coupled && !f₁.isdual[i]) ||
throw(SectorMismatch("cannot connect $(f₂.uncoupled) to $(f₁.uncoupled[i])"))
coeff = Fsymbol(one(I), one(I), one(I), one(I), one(I), one(I))[1, 1, 1, 1]
coeff = one(sectorscalartype(I))

uncoupled = TupleTools.deleteat(f₁.uncoupled, i)
coupled = f₁.coupled
Expand All @@ -39,7 +39,7 @@
# identity operation
(f₁.uncoupled[i] == f₂.coupled && !f₁.isdual[i]) ||
throw(SectorMismatch("cannot connect $(f₂.uncoupled) to $(f₁.uncoupled[i])"))
coeff = Fsymbol(one(I), one(I), one(I), one(I), one(I), one(I))[1, 1, 1, 1]
coeff = one(sectorscalartype(I))
isdual′ = TupleTools.setindex(f₁.isdual, f₂.isdual[1], i)
f = FusionTree{I}(f₁.uncoupled, f₁.coupled, isdual′, f₁.innerlines, f₁.vertices)
return fusiontreedict(I)(f => coeff)
Expand All @@ -59,7 +59,7 @@
isdual′ = (isdualb, isdualc, tail(isdual)...)
inner′ = (uncoupled[1], inner...)
vertices′ = (f₂.vertices..., f₁.vertices...)
coeff = Fsymbol(one(I), one(I), one(I), one(I), one(I), one(I))[1, 1, 1, 1]
coeff = one(sectorscalartype(I))
f′ = FusionTree(uncoupled′, coupled, isdual′, inner′, vertices′)
return fusiontreedict(I)(f′ => coeff)
end
Expand Down Expand Up @@ -110,8 +110,8 @@
F = fusiontreetype(I, N₁ + N₂ - 1)
(f₁.uncoupled[i] == f₂.coupled && !f₁.isdual[i]) ||
throw(SectorMismatch("cannot connect $(f₂.uncoupled) to $(f₁.uncoupled[i])"))
coeff = Fsymbol(one(I), one(I), one(I), one(I), one(I), one(I))[1, 1]
T = typeof(coeff)
T = sectorscalartype(I)
coeff = one(T)
if length(f₁) == 1
return fusiontreedict(I){F,T}(f₂ => coeff)
end
Expand Down Expand Up @@ -457,8 +457,8 @@
# precompute the parameters of the return type
F₁ = fusiontreetype(I, N)
F₂ = fusiontreetype(I, N₁ + N₂ - N)
coeff = @inbounds Fsymbol(one(I), one(I), one(I), one(I), one(I), one(I))[1, 1, 1, 1]
T = typeof(coeff)
T = sectorscalartype(I)
coeff = one(T)
if N == N₁
return fusiontreedict(I){Tuple{F₁,F₂},T}((f₁, f₂) => coeff)
else
Expand Down Expand Up @@ -500,8 +500,7 @@
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
@assert iscyclicpermutation(p)
if usetransposecache[]
u = one(I)
T = eltype(Fsymbol(u, u, u, u, u, u))
T = sectorscalartype(I)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
D = fusiontreedict(I){Tuple{F₁,F₂},T}
Expand Down Expand Up @@ -587,8 +586,7 @@
map(l -> l - count(l .> q′), TupleTools.getindices(linearindex, p2)))
end

u = one(I)
T = typeof(Fsymbol(u, u, u, u, u, u)[1, 1, 1, 1])
T = sectorscalartype(I)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
newtrees = FusionTreeDict{Tuple{F₁,F₂},T}()
Expand All @@ -615,8 +613,7 @@
"""
function planar_trace(f::FusionTree{I,N},
q1::IndexTuple{N₃}, q2::IndexTuple{N₃}) where {I<:Sector,N,N₃}
u = one(I)
T = typeof(Fsymbol(u, u, u, u, u, u)[1, 1, 1, 1])
T = sectorscalartype(I)
F = fusiontreetype(I, N - 2 * N₃)
newtrees = FusionTreeDict{F,T}()
N₃ === 0 && return push!(newtrees, f => one(T))
Expand Down Expand Up @@ -673,8 +670,7 @@
i < N || f.coupled == one(I) ||
throw(ArgumentError("Cannot trace outputs i=$N and 1 of fusion tree that couples to non-trivial sector"))

u = one(I)
T = typeof(Fsymbol(u, u, u, u, u, u)[1, 1, 1, 1])
T = sectorscalartype(I)
F = fusiontreetype(I, N - 2)
newtrees = FusionTreeDict{F,T}()

Expand Down Expand Up @@ -786,12 +782,7 @@
inner_extended = (uncoupled[1], inner..., coupled′)
vertices = f.vertices
u = one(I)

if BraidingStyle(I) isa NoBraiding
oneT = Fsymbol(u, u, u, u, u, u)[1, 1, 1, 1]
else
oneT = Rsymbol(u, u, u)[1, 1] * Fsymbol(u, u, u, u, u, u)[1, 1, 1, 1]
end
oneT = one(sectorscalartype(I))

if u in (uncoupled[i], uncoupled[i + 1])
# braiding with trivial sector: simple and always possible
Expand Down Expand Up @@ -925,7 +916,7 @@
p::NTuple{N,Int}) where {I<:Sector,N}
TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p"))
if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding
coeff = Rsymbol(one(I), one(I), one(I))
coeff = one(sectorscalartype(I))
for i in 1:N
for j in 1:(i - 1)
if p[j] > p[i]
Expand All @@ -940,10 +931,7 @@
f′ = FusionTree{I}(uncoupled′, coupled′, isdual′)
return fusiontreedict(I)(f′ => coeff)
else
u = one(I)
T = BraidingStyle(I) isa NoBraiding ?
typeof(Fsymbol(u, u, u, u, u, u)[1, 1, 1, 1]) :
typeof(Rsymbol(u, u, u)[1, 1] * Fsymbol(u, u, u, u, u, u)[1, 1, 1, 1])
T = sectorscalartype(I)
coeff = one(T)
trees = FusionTreeDict(f => coeff)
newtrees = empty(trees)
Expand Down Expand Up @@ -1009,8 +997,7 @@
if FusionStyle(f₁) isa UniqueFusion &&
BraidingStyle(f₁) isa SymmetricBraiding
if usebraidcache_abelian[]
u = one(I)
T = Int
T = Int # do we hardcode this ?

Check warning on line 1000 in src/fusiontrees/manipulations.jl

View check run for this annotation

Codecov / codecov/patch

src/fusiontrees/manipulations.jl#L1000

Added line #L1000 was not covered by tests
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
D = SingletonDict{Tuple{F₁,F₂},T}
Expand All @@ -1020,11 +1007,7 @@
end
else
if usebraidcache_nonabelian[]
u = one(I)
T = BraidingStyle(I) isa NoBraiding ?
typeof(Fsymbol(u, u, u, u, u, u)[1, 1, 1, 1]) :
typeof(sqrtdim(u) * Fsymbol(u, u, u, u, u, u)[1, 1, 1, 1] *
Rsymbol(u, u, u)[1, 1])
T = sectorscalartype(I)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
D = FusionTreeDict{Tuple{F₁,F₂},T}
Expand Down