From 7e6d332daa0e42acd266b09df7bbd727527c119d Mon Sep 17 00:00:00 2001 From: Todorbsc Date: Fri, 11 Oct 2024 12:29:48 +0200 Subject: [PATCH 1/3] Add verification of dimension mismatch in TensorNetwork constructor --- src/TensorNetwork.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index d5e59b3c6..a36acd677 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -45,7 +45,6 @@ struct TensorNetwork <: AbstractTensorNetwork tensormap = IdDict{Tensor,Vector{Symbol}}(tensor => inds(tensor) for tensor in tensors) indexmap = reduce(tensors; init=Dict{Symbol,Vector{Tensor}}()) do dict, tensor - # TODO check for inconsistent dimensions? for index in inds(tensor) # TODO use lambda? `Tensor[]` might be reused push!(get!(dict, index, Tensor[]), tensor) @@ -53,6 +52,12 @@ struct TensorNetwork <: AbstractTensorNetwork dict end + # Check for inconsistent index dimensions + for ind in keys(indexmap) + dims = map(tns -> size(tns)[findfirst(==(ind), tensormap[tns])], indexmap[ind]) + length(unique(dims)) == 1 || throw(DimensionMismatch("Index $(ind) has inconsistent dimension: $(dims)")) + end + return new(indexmap, tensormap, CachedField{Vector{Tensor}}()) end end From a372735e018e36096e48d9c162e79af947df9fe4 Mon Sep 17 00:00:00 2001 From: Todorbsc Date: Fri, 11 Oct 2024 12:30:08 +0200 Subject: [PATCH 2/3] Remove testskip of inconsistent dimensions --- test/TensorNetwork_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/TensorNetwork_test.jl b/test/TensorNetwork_test.jl index 6b5448fbb..c30b97b31 100644 --- a/test/TensorNetwork_test.jl +++ b/test/TensorNetwork_test.jl @@ -23,7 +23,7 @@ @testset "TensorNetwork with tensors of different dimensions" begin tensor1 = Tensor(zeros(2, 2), (:i, :j)) tensor2 = Tensor(zeros(3, 3), (:j, :k)) - @test_skip @test_throws DimensionMismatch tn = TensorNetwork([tensor1, tensor2]) + @test_throws DimensionMismatch tn = TensorNetwork([tensor1, tensor2]) end end From 36bb209493ddb9e9028dd918966f749663a9e3a7 Mon Sep 17 00:00:00 2001 From: Todorbsc <145352308+Todorbsc@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:51:02 +0200 Subject: [PATCH 3/3] Rename tns as tensor (suggested by Sergio) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/TensorNetwork.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index a36acd677..ad8e601b5 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -54,7 +54,7 @@ struct TensorNetwork <: AbstractTensorNetwork # Check for inconsistent index dimensions for ind in keys(indexmap) - dims = map(tns -> size(tns)[findfirst(==(ind), tensormap[tns])], indexmap[ind]) + dims = map(tensor -> size(tensor, ind), indexmap[ind]) length(unique(dims)) == 1 || throw(DimensionMismatch("Index $(ind) has inconsistent dimension: $(dims)")) end