From fbbe84d004f9d6b146b18a67bb96ecbfa91d9422 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Mon, 29 Jul 2024 20:31:56 +0200 Subject: [PATCH 01/10] First draft --- GNNLux/src/GNNLux.jl | 4 +++- GNNLux/src/layers/temporalconv.jl | 26 ++++++++++++++++++++++++++ GNNlib/src/GNNlib.jl | 2 +- GNNlib/src/layers/temporalconv.jl | 9 +++++++++ src/layers/temporalconv.jl | 4 +--- 5 files changed, 40 insertions(+), 5 deletions(-) create mode 100644 GNNLux/src/layers/temporalconv.jl diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index ecac67b5a..561843477 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -34,7 +34,9 @@ export AGNNConv, # SGConv, # TAGConv, # TransformerConv - + +include("layers/temporalconv.jl") +export TGCNCell end #module \ No newline at end of file diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl new file mode 100644 index 000000000..01788bcf7 --- /dev/null +++ b/GNNLux/src/layers/temporalconv.jl @@ -0,0 +1,26 @@ +@concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)} + in_dims::Int + out_dims::Int + conv + gru +end + +function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true) + in_dims, out_dims = ch + conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true) + gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state) + return TGCNCell(in_dims, out_dims, conv, gru) +end + +LuxCore.outputsize(l::TGCNCell) = (l.out_dims,) + +function (l::TGCNCell)(h, g, x) + conv = StatefulLuxLayer{true}(l.conv, ps.conv, _getstate(st, :conv)) + gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) + m = (; conv, gru) + return GNNlib.tgcn_conv(m, h, g, x) +end + +function Base.show(io::IO, tgcn::TGCNCell) + print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))") +end \ No newline at end of file diff --git a/GNNlib/src/GNNlib.jl b/GNNlib/src/GNNlib.jl index a84253776..3ca4acc9b 100644 --- a/GNNlib/src/GNNlib.jl +++ b/GNNlib/src/GNNlib.jl @@ -61,7 +61,7 @@ export agnn_conv, transformer_conv include("layers/temporalconv.jl") -export a3tgcn_conv +export tgcn_conv include("layers/pool.jl") export global_pool, diff --git a/GNNlib/src/layers/temporalconv.jl b/GNNlib/src/layers/temporalconv.jl index 8cff3f033..13198cdee 100644 --- a/GNNlib/src/layers/temporalconv.jl +++ b/GNNlib/src/layers/temporalconv.jl @@ -1,3 +1,12 @@ +####################### TGCN ###################################### + +function tgcn_conv(l, h, g::GNNGraph, x::AbstractArray) + x̃ = l.conv(g, x) + h, x̃ = l.gru(h, x̃) + return h, x̃ +end + + function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray) h = a3tgcn.tgcn(g, x) e = a3tgcn.dense1(h) diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index 44688cea4..c8c74b907 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -35,9 +35,7 @@ function TGCNCell(ch::Pair{Int, Int}; end function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray) - x̃ = tgcn.conv(g, x) - h, x̃ = tgcn.gru(h, x̃) - return h, x̃ + return GNNlib.tgcn_conv(tgcn, h, g, x) end function Base.show(io::IO, tgcn::TGCNCell) From dc39f81f004f0616dfcf6619490e5db47e815dff Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Thu, 8 Aug 2024 13:50:18 +0200 Subject: [PATCH 02/10] Fix signature --- GNNLux/src/layers/temporalconv.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl index 01788bcf7..3bcf078c2 100644 --- a/GNNLux/src/layers/temporalconv.jl +++ b/GNNLux/src/layers/temporalconv.jl @@ -1,3 +1,7 @@ +@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)} + cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer} +end + @concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)} in_dims::Int out_dims::Int @@ -14,7 +18,7 @@ end LuxCore.outputsize(l::TGCNCell) = (l.out_dims,) -function (l::TGCNCell)(h, g, x) +function (l::TGCNCell)(h, g, x, ps, st) conv = StatefulLuxLayer{true}(l.conv, ps.conv, _getstate(st, :conv)) gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) m = (; conv, gru) From 09338181c749007f1bdbd2ba68606ae764e566f1 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Tue, 13 Aug 2024 14:41:59 +0200 Subject: [PATCH 03/10] Improvement --- GNNLux/src/layers/temporalconv.jl | 34 ++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl index 3bcf078c2..9cf7fe281 100644 --- a/GNNLux/src/layers/temporalconv.jl +++ b/GNNLux/src/layers/temporalconv.jl @@ -1,7 +1,26 @@ -@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)} +@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)} cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer} end +function initialstates(rng::AbstractRNG, r::StatefulRecurrentCell) + return (cell=initialstates(rng, r.cell), carry=nothing) +end + +function (r::StatefulRecurrentCell)(g, x, ps, st::NamedTuple) + (out, carry), st_ = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry) + return out, (; cell=st_, carry) +end + +function applyrecurrentcell(l, g, x, ps, st, carry) + return Lux.apply(l, g, (x, carry), ps, st) +end + +function applyrecurrentcell(l, g, x, ps, st, ::Nothing) + return Lux.apply(l, g, x, ps, st) +end + +LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st) + @concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)} in_dims::Int out_dims::Int @@ -18,13 +37,18 @@ end LuxCore.outputsize(l::TGCNCell) = (l.out_dims,) -function (l::TGCNCell)(h, g, x, ps, st) +function (l::TGCNCell)(g, x, h, ps, st) conv = StatefulLuxLayer{true}(l.conv, ps.conv, _getstate(st, :conv)) gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) - m = (; conv, gru) - return GNNlib.tgcn_conv(m, h, g, x) + #m = (; conv, gru) + + x̃, stconv = l.conv(g, x, ps.conv, st.conv) + (h, (h,)), st = l.gru((x̃,(h,)), ps.gru,st.gru) + return (h, (h,)), st end function Base.show(io::IO, tgcn::TGCNCell) print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))") -end \ No newline at end of file +end + +tgcn = StatefulRecurrentCell(TGCNCell(1 =>3)) \ No newline at end of file From c5d90ee201cf80d0e974c80cfbf243636e3192af Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Thu, 15 Aug 2024 15:29:07 +0200 Subject: [PATCH 04/10] Export TGCN --- GNNLux/src/GNNLux.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 561843477..f9d944508 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -36,7 +36,7 @@ export AGNNConv, # TransformerConv include("layers/temporalconv.jl") -export TGCNCell +export TGCN end #module \ No newline at end of file From c06a465ae1d6005eb57616f271b727d2257d7e7f Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Thu, 15 Aug 2024 15:29:18 +0200 Subject: [PATCH 05/10] Fixes --- GNNLux/src/layers/temporalconv.jl | 45 +++++++++++++++++-------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl index 9cf7fe281..f9180dc11 100644 --- a/GNNLux/src/layers/temporalconv.jl +++ b/GNNLux/src/layers/temporalconv.jl @@ -2,21 +2,25 @@ cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer} end -function initialstates(rng::AbstractRNG, r::StatefulRecurrentCell) - return (cell=initialstates(rng, r.cell), carry=nothing) +function LuxCore.initialstates(rng::AbstractRNG, r::GNNLux.StatefulRecurrentCell) + return (cell=LuxCore.initialstates(rng, r.cell), carry=nothing) end -function (r::StatefulRecurrentCell)(g, x, ps, st::NamedTuple) - (out, carry), st_ = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry) - return out, (; cell=st_, carry) +function (r::StatefulRecurrentCell)(g, x::AbstractMatrix, ps, st::NamedTuple) + (out, carry), st = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry) + return out, (; cell=st, carry) end -function applyrecurrentcell(l, g, x, ps, st, carry) - return Lux.apply(l, g, (x, carry), ps, st) +function (r::StatefulRecurrentCell)(g, x::AbstractVector, ps, st::NamedTuple) + (out, carry), st = applyrecurrentcell(r.cell, g, first(x), ps, st.cell, st.carry) + for xᵢ in x[(begin + 1):end] + (out, carry), st = applyrecurrentcell(r.cell, g, xᵢ, ps, st, carry) + end + return out, (; cell=st, carry) end -function applyrecurrentcell(l, g, x, ps, st, ::Nothing) - return Lux.apply(l, g, x, ps, st) +function applyrecurrentcell(l, g, x, ps, st, carry) + return Lux.apply(l, g, (x, carry), ps, st) end LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st) @@ -26,29 +30,30 @@ LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st) out_dims::Int conv gru + init_state::Function end function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true) in_dims, out_dims = ch conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true) gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state) - return TGCNCell(in_dims, out_dims, conv, gru) + return TGCNCell(in_dims, out_dims, conv, gru, init_state) end -LuxCore.outputsize(l::TGCNCell) = (l.out_dims,) - -function (l::TGCNCell)(g, x, h, ps, st) - conv = StatefulLuxLayer{true}(l.conv, ps.conv, _getstate(st, :conv)) - gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) - #m = (; conv, gru) - +function (l::TGCNCell)(g, (x, h), ps, st) + if h === nothing + h = l.init_state(l.out_dims, 1) + end x̃, stconv = l.conv(g, x, ps.conv, st.conv) - (h, (h,)), st = l.gru((x̃,(h,)), ps.gru,st.gru) - return (h, (h,)), st + (h, (h,)), stgru = l.gru((x̃,(h,)), ps.gru,st.gru) + return (h, h), (conv=stconv, gru=stgru) end +LuxCore.outputsize(l::TGCNCell) = (l.out_dims,) +LuxCore.outputsize(l::GNNLux.StatefulRecurrentCell) = (l.cell.out_dims,) + function Base.show(io::IO, tgcn::TGCNCell) print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))") end -tgcn = StatefulRecurrentCell(TGCNCell(1 =>3)) \ No newline at end of file +TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...)) \ No newline at end of file From d700acea5e21673e41e54551249f63a1870805b6 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Thu, 15 Aug 2024 15:29:39 +0200 Subject: [PATCH 06/10] Back to previous version --- src/layers/temporalconv.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index c8c74b907..aab7ab02a 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -35,7 +35,9 @@ function TGCNCell(ch::Pair{Int, Int}; end function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray) - return GNNlib.tgcn_conv(tgcn, h, g, x) + x̃ = l.conv(g, x) + h, x̃ = l.gru(h, x̃) + return h, x̃ end function Base.show(io::IO, tgcn::TGCNCell) From 852ec8467447bd106b065fd3459238ec07eaf377 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Thu, 15 Aug 2024 15:29:50 +0200 Subject: [PATCH 07/10] Add test --- GNNLux/test/layers/temporalconv_test.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 GNNLux/test/layers/temporalconv_test.jl diff --git a/GNNLux/test/layers/temporalconv_test.jl b/GNNLux/test/layers/temporalconv_test.jl new file mode 100644 index 000000000..bdde7b325 --- /dev/null +++ b/GNNLux/test/layers/temporalconv_test.jl @@ -0,0 +1,15 @@ +@testitem "layers/temporalconv" setup=[SharedTestSetup] begin + using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme + + rng = StableRNG(1234) + g = rand_graph(10, 40, seed=1234) + x = randn(rng, Float32, 3, 10) + + @testset "TGCN" begin + l = TGCN(3=>3) + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) + end +end \ No newline at end of file From e509abffa1639a8b1639c9add31fa8828afd293c Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Sun, 18 Aug 2024 09:34:30 +0200 Subject: [PATCH 08/10] Remove GNNlib code --- GNNlib/src/layers/temporalconv.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/GNNlib/src/layers/temporalconv.jl b/GNNlib/src/layers/temporalconv.jl index 13198cdee..8cff3f033 100644 --- a/GNNlib/src/layers/temporalconv.jl +++ b/GNNlib/src/layers/temporalconv.jl @@ -1,12 +1,3 @@ -####################### TGCN ###################################### - -function tgcn_conv(l, h, g::GNNGraph, x::AbstractArray) - x̃ = l.conv(g, x) - h, x̃ = l.gru(h, x̃) - return h, x̃ -end - - function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray) h = a3tgcn.tgcn(g, x) e = a3tgcn.dense1(h) From e4b6f8017746c94f5a2a15da5b67f3a919b30f6b Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Sun, 18 Aug 2024 09:58:09 +0200 Subject: [PATCH 09/10] Fix --- src/layers/temporalconv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index aab7ab02a..44688cea4 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -35,8 +35,8 @@ function TGCNCell(ch::Pair{Int, Int}; end function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray) - x̃ = l.conv(g, x) - h, x̃ = l.gru(h, x̃) + x̃ = tgcn.conv(g, x) + h, x̃ = tgcn.gru(h, x̃) return h, x̃ end From 7d782000870bf8e7d20a45b07022288be5befcb5 Mon Sep 17 00:00:00 2001 From: Aurora Rossi <65721467+aurorarossi@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:29:05 +0200 Subject: [PATCH 10/10] Fix Co-authored-by: Carlo Lucibello --- GNNLux/src/layers/temporalconv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl index f9180dc11..2ab6235b3 100644 --- a/GNNLux/src/layers/temporalconv.jl +++ b/GNNLux/src/layers/temporalconv.jl @@ -12,8 +12,8 @@ function (r::StatefulRecurrentCell)(g, x::AbstractMatrix, ps, st::NamedTuple) end function (r::StatefulRecurrentCell)(g, x::AbstractVector, ps, st::NamedTuple) - (out, carry), st = applyrecurrentcell(r.cell, g, first(x), ps, st.cell, st.carry) - for xᵢ in x[(begin + 1):end] + st, carry = st.cell, st.carry + for xᵢ in x (out, carry), st = applyrecurrentcell(r.cell, g, xᵢ, ps, st, carry) end return out, (; cell=st, carry)