diff --git a/docs/src/transformations.md b/docs/src/transformations.md index c84454534..34a280509 100644 --- a/docs/src/transformations.md +++ b/docs/src/transformations.md @@ -179,7 +179,7 @@ fig #hide ### Column reduction ```@docs -Tenet.ColumnReduction +Tenet.Truncate ``` ```@example plot @@ -194,7 +194,7 @@ B = Tensor(rand(3, 3), (:j, :l)) #hide C = Tensor(rand(3, 3), (:l, :m)) #hide tn = TensorNetwork([A, B, C]) #hide -reduced = transform(tn, Tenet.ColumnReduction) #hide +reduced = transform(tn, Tenet.Truncate) #hide smooth_annotation!( #hide fig[1, 1]; #hide diff --git a/src/Transformations.jl b/src/Transformations.jl index 45e46d528..7056906c3 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -146,6 +146,38 @@ function transform!(tn::TensorNetwork, config::ContractSimplification) return tn end +""" + Truncate <: Transformation + +Truncate the dimension of a `Tensor` in a [`TensorNetwork`](@ref) when it contains columns with all elements smaller than `atol`. + +# Keyword Arguments + + - `atol` Absolute tolerance. Defaults to `1e-12`. + - `skip` List of indices to skip. Defaults to `[]`. +""" +Base.@kwdef struct Truncate <: Transformation + atol::Float64 = 1e-12 + skip::Vector{Symbol} = Symbol[] +end + +function transform!(tn::TensorNetwork, config::Truncate) + skip_inds = isempty(config.skip) ? inds(tn; set=:open) : config.skip + + for tensor in tensors(tn) + for (dim, index) in enumerate(inds(tensor)) + index ∈ skip_inds && continue + + zeroslices = iszero.(eachslice(tensor; dims=dim)) + any(zeroslices) || continue + + slice!(tn, index, count(!, zeroslices) == 1 ? findfirst(!, zeroslices) : findall(!, zeroslices)) + end + end + + return tn +end + """ DiagonalReduction <: Transformation @@ -233,38 +265,6 @@ function transform!(tn::TensorNetwork, config::AntiDiagonalGauging) return tn end -""" - ColumnReduction <: Transformation - -Truncate the dimension of a `Tensor` in a [`TensorNetwork`](@ref) when it contains columns with all elements smaller than `atol`. - -# Keyword Arguments - - - `atol` Absolute tolerance. Defaults to `1e-12`. - - `skip` List of indices to skip. Defaults to `[]`. -""" -Base.@kwdef struct ColumnReduction <: Transformation - atol::Float64 = 1e-12 - skip::Vector{Symbol} = Symbol[] -end - -function transform!(tn::TensorNetwork, config::ColumnReduction) - skip_inds = isempty(config.skip) ? inds(tn; set=:open) : config.skip - - for tensor in tensors(tn) - for (dim, index) in enumerate(inds(tensor)) - index ∈ skip_inds && continue - - zeroslices = iszero.(eachslice(tensor; dims=dim)) - any(zeroslices) || continue - - slice!(tn, index, count(!, zeroslices) == 1 ? findfirst(!, zeroslices) : findall(!, zeroslices)) - end - end - - return tn -end - """ SplitSimplification <: Transformation diff --git a/test/Transformations_test.jl b/test/Transformations_test.jl index 5dfb90c37..dc5fe0d50 100644 --- a/test/Transformations_test.jl +++ b/test/Transformations_test.jl @@ -187,8 +187,8 @@ @test contract(gauged) ≈ contract(tn) end - @testset "ColumnReduction" begin - using Tenet: ColumnReduction + @testset "Truncate" begin + using Tenet: Truncate @testset "range" begin data = rand(3, 3, 3) @@ -199,7 +199,7 @@ C = Tensor(rand(3, 3), (:j, :m)) tn = TensorNetwork([A, B, C]) - reduced = transform(tn, ColumnReduction) + reduced = transform(tn, Truncate) @test :j ∉ inds(reduced) @test contract(reduced) ≈ contract(tn) @@ -214,7 +214,7 @@ C = Tensor(rand(3, 3), (:j, :m)) tn = TensorNetwork([A, B, C]) - reduced = transform(tn, ColumnReduction) + reduced = transform(tn, Truncate) @test size(reduced, :j) == 2 @test contract(reduced) ≈ contract(tn)