Skip to content

Commit

Permalink
Rename ColumnReduction to Truncate
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Apr 28, 2024
1 parent 6c134f9 commit 122715b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 38 deletions.
4 changes: 2 additions & 2 deletions docs/src/transformations.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ fig #hide
### Column reduction

```@docs
Tenet.ColumnReduction
Tenet.Truncate
```

```@example plot
Expand All @@ -194,7 +194,7 @@ B = Tensor(rand(3, 3), (:j, :l)) #hide
C = Tensor(rand(3, 3), (:l, :m)) #hide
tn = TensorNetwork([A, B, C]) #hide
reduced = transform(tn, Tenet.ColumnReduction) #hide
reduced = transform(tn, Tenet.Truncate) #hide
smooth_annotation!( #hide
fig[1, 1]; #hide
Expand Down
64 changes: 32 additions & 32 deletions src/Transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,38 @@ function transform!(tn::TensorNetwork, config::ContractSimplification)
return tn
end

"""
Truncate <: Transformation
Truncate the dimension of a `Tensor` in a [`TensorNetwork`](@ref) when it contains columns with all elements smaller than `atol`.
# Keyword Arguments
- `atol` Absolute tolerance. Defaults to `1e-12`.
- `skip` List of indices to skip. Defaults to `[]`.
"""
Base.@kwdef struct Truncate <: Transformation
atol::Float64 = 1e-12
skip::Vector{Symbol} = Symbol[]
end

function transform!(tn::TensorNetwork, config::Truncate)
skip_inds = isempty(config.skip) ? inds(tn; set=:open) : config.skip

for tensor in tensors(tn)
for (dim, index) in enumerate(inds(tensor))
index skip_inds && continue

zeroslices = iszero.(eachslice(tensor; dims=dim))
any(zeroslices) || continue

slice!(tn, index, count(!, zeroslices) == 1 ? findfirst(!, zeroslices) : findall(!, zeroslices))
end
end

return tn
end

"""
DiagonalReduction <: Transformation
Expand Down Expand Up @@ -233,38 +265,6 @@ function transform!(tn::TensorNetwork, config::AntiDiagonalGauging)
return tn
end

"""
ColumnReduction <: Transformation
Truncate the dimension of a `Tensor` in a [`TensorNetwork`](@ref) when it contains columns with all elements smaller than `atol`.
# Keyword Arguments
- `atol` Absolute tolerance. Defaults to `1e-12`.
- `skip` List of indices to skip. Defaults to `[]`.
"""
Base.@kwdef struct ColumnReduction <: Transformation
atol::Float64 = 1e-12
skip::Vector{Symbol} = Symbol[]
end

function transform!(tn::TensorNetwork, config::ColumnReduction)
skip_inds = isempty(config.skip) ? inds(tn; set=:open) : config.skip

for tensor in tensors(tn)
for (dim, index) in enumerate(inds(tensor))
index skip_inds && continue

zeroslices = iszero.(eachslice(tensor; dims=dim))
any(zeroslices) || continue

slice!(tn, index, count(!, zeroslices) == 1 ? findfirst(!, zeroslices) : findall(!, zeroslices))
end
end

return tn
end

"""
SplitSimplification <: Transformation
Expand Down
8 changes: 4 additions & 4 deletions test/Transformations_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@
@test contract(gauged) contract(tn)
end

@testset "ColumnReduction" begin
using Tenet: ColumnReduction
@testset "Truncate" begin
using Tenet: Truncate

@testset "range" begin
data = rand(3, 3, 3)
Expand All @@ -199,7 +199,7 @@
C = Tensor(rand(3, 3), (:j, :m))

tn = TensorNetwork([A, B, C])
reduced = transform(tn, ColumnReduction)
reduced = transform(tn, Truncate)

@test :j inds(reduced)
@test contract(reduced) contract(tn)
Expand All @@ -214,7 +214,7 @@
C = Tensor(rand(3, 3), (:j, :m))

tn = TensorNetwork([A, B, C])
reduced = transform(tn, ColumnReduction)
reduced = transform(tn, Truncate)

@test size(reduced, :j) == 2
@test contract(reduced) contract(tn)
Expand Down

0 comments on commit 122715b

Please sign in to comment.