Skip to content

Commit

Permalink
Merge branch 'master' into feature/enhance-docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jofrevalles authored Sep 15, 2023
2 parents 7d80d17 + ef96be2 commit bc53d77
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Bijections = "0.1"
ChainRulesCore = "1.0"
Combinatorics = "1.0"
DeltaArrays = "0.1.1"
EinExprs = "0.5.2"
EinExprs = "0.5.5"
GraphMakie = "0.4,0.5"
Graphs = "1.7"
Makie = "0.18, 0.19"
Expand Down
1 change: 0 additions & 1 deletion docs/src/assets/youtube.css
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
margin-left: auto;
margin-right: auto;
text-align: center;
height: 315px;
}
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ A video of its presentation at JuliaCon 2023 can be seen here:

```@raw html
<div class="youtube-video">
<iframe width="560" height="315" src="https://www.youtube-nocookie.com/embed/8BHGtm6FRMk?si=bPXB6bPtK695HFIR" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>
<iframe width="560" style="height='315'" src="https://www.youtube-nocookie.com/embed/8BHGtm6FRMk?si=bPXB6bPtK695HFIR" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>
</div>
```

Expand Down
6 changes: 4 additions & 2 deletions ext/TenetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ function ChainRulesCore.ProjectTo(tn::T) where {T<:TensorNetwork}
ProjectTo{T}(; tensors = ProjectTo(tn.tensors), metadata = tn.metadata)
end

function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {A<:Ansatz,T<:TensorNetwork{A}}
TensorNetwork{A}(projector.tensors(dx.tensors); projector.metadata...)
function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:TensorNetwork}
dx.tensors isa NoTangent && return NoTangent()
Tangent{TensorNetwork}(tensors = projector.tensors(dx.tensors))
end

function Base.:+(x::TensorNetwork{A}, Δ::Tangent{TensorNetwork}) where {A<:Ansatz}
# TODO match tensors by indices
tensors = map(+, x.tensors, Δ.tensors)
TensorNetwork{A}(tensors; x.metadata...)
end
Expand Down
14 changes: 5 additions & 9 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,16 +480,12 @@ The `kwargs` will be passed down to the [`einexpr`](@ref) function.
See also: [`einexpr`](@ref), [`contract!`](@ref).
"""
function contract(tn::TensorNetwork; kwargs...)
path = einexpr(tn; kwargs...)
function contract(tn::TensorNetwork; path = einexpr(tn))
# TODO does `first` work always?
length(path.args) == 0 && return select(tn, inds(path)) |> first

tn = copy(tn)

for indices in contractorder(path)
contract!(tn, indices)
end

tensors(tn) |> only
intermediates = map(subpath -> contract(tn; path = subpath), path.args)
contract(intermediates...; dims = suminds(path))
end

contract!(t::Tensor, tn::TensorNetwork; kwargs...) = contract!(tn, t; kwargs...)
Expand Down
4 changes: 4 additions & 0 deletions test/integration/ChainRules_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
Base.collect(tn::TensorNetwork) = tensors(tn)

@testset "TensorNetwork" begin
# TODO it crashes
# test_frule(TensorNetwork, Tensor[])
# test_rrule(TensorNetwork, Tensor[])

a = Tensor(rand(4, 2), (:i, :j))
b = Tensor(rand(2, 3), (:j, :k))

Expand Down

0 comments on commit bc53d77

Please sign in to comment.