From 00aa2a3adeff8009ed22c821bd0578ba3b2bd915 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 12 Sep 2023 12:50:41 +0200 Subject: [PATCH] Fix autodiff on `contract`ion of `Tensor`s --- ext/TenetChainRulesCoreExt.jl | 5 +++++ src/Numerics.jl | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/ext/TenetChainRulesCoreExt.jl b/ext/TenetChainRulesCoreExt.jl index 3b5f66730..58f4b3eff 100644 --- a/ext/TenetChainRulesCoreExt.jl +++ b/ext/TenetChainRulesCoreExt.jl @@ -20,6 +20,11 @@ function ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) return T(data, inds; meta...), Tensor_pullback end +@non_differentiable copy(tn::TensorNetwork) + +# NOTE fix problem with vector generator in `contract` +@non_differentiable Tenet.__omeinsum_sym2str(x) + # WARN type-piracy @non_differentiable setdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...) @non_differentiable union(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...) diff --git a/src/Numerics.jl b/src/Numerics.jl index 8a30dc73b..1d75ad7cc 100644 --- a/src/Numerics.jl +++ b/src/Numerics.jl @@ -29,19 +29,27 @@ for op in [ @eval Base.$op(a::Tensor{A,0}, b::Tensor{B,0}) where {A,B} = broadcast($op, a, b) end +# NOTE used for marking non-differentiability +# NOTE use `String[...]` code instead of `map` or broadcasting to set eltype in empty cases +__omeinsum_sym2str(x) = String[string(i) for i in x] + """ contract(a::Tensor[, b::Tensor, dims=nonunique([inds(a)..., inds(b)...])]) Perform tensor contraction operation. """ function contract(a::Tensor, b::Tensor; dims = (∩(inds(a), inds(b)))) - ia = inds(a) - ib = inds(b) + ia = inds(a) |> collect + ib = inds(b) |> collect i = ∩(dims, ia, ib) - ic = tuple(setdiff(ia ∪ ib, i isa Base.AbstractVecOrTuple ? i : (i,))...) + ic = setdiff(ia ∪ ib, i isa Base.AbstractVecOrTuple ? i : (i,))::Vector{Symbol} + + _ia = __omeinsum_sym2str(ia) + _ib = __omeinsum_sym2str(ib) + _ic = __omeinsum_sym2str(ic) - data = EinCode((String.(ia), String.(ib)), String.(ic))(parent(a), parent(b)) + data = EinCode((_ia, _ib), _ic)(parent(a), parent(b)) # TODO merge metadata? return Tensor(data, ic)