From fbbe84d004f9d6b146b18a67bb96ecbfa91d9422 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Mon, 29 Jul 2024 20:31:56 +0200 Subject: [PATCH] First draft --- GNNLux/src/GNNLux.jl | 4 +++- GNNLux/src/layers/temporalconv.jl | 26 ++++++++++++++++++++++++++ GNNlib/src/GNNlib.jl | 2 +- GNNlib/src/layers/temporalconv.jl | 9 +++++++++ src/layers/temporalconv.jl | 4 +--- 5 files changed, 40 insertions(+), 5 deletions(-) create mode 100644 GNNLux/src/layers/temporalconv.jl diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index ecac67b5a..561843477 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -34,7 +34,9 @@ export AGNNConv, # SGConv, # TAGConv, # TransformerConv - + +include("layers/temporalconv.jl") +export TGCNCell end #module \ No newline at end of file diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl new file mode 100644 index 000000000..01788bcf7 --- /dev/null +++ b/GNNLux/src/layers/temporalconv.jl @@ -0,0 +1,26 @@ +@concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)} + in_dims::Int + out_dims::Int + conv + gru +end + +function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true) + in_dims, out_dims = ch + conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true) + gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state) + return TGCNCell(in_dims, out_dims, conv, gru) +end + +LuxCore.outputsize(l::TGCNCell) = (l.out_dims,) + +function (l::TGCNCell)(h, g, x) + conv = StatefulLuxLayer{true}(l.conv, ps.conv, _getstate(st, :conv)) + gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) + m = (; conv, gru) + return GNNlib.tgcn_conv(m, h, g, x) +end + +function Base.show(io::IO, tgcn::TGCNCell) + print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))") +end \ No newline at end of file diff --git a/GNNlib/src/GNNlib.jl b/GNNlib/src/GNNlib.jl index a84253776..3ca4acc9b 100644 --- a/GNNlib/src/GNNlib.jl +++ b/GNNlib/src/GNNlib.jl @@ -61,7 +61,7 @@ export agnn_conv, transformer_conv include("layers/temporalconv.jl") -export a3tgcn_conv +export tgcn_conv include("layers/pool.jl") export global_pool, diff --git a/GNNlib/src/layers/temporalconv.jl b/GNNlib/src/layers/temporalconv.jl index 8cff3f033..13198cdee 100644 --- a/GNNlib/src/layers/temporalconv.jl +++ b/GNNlib/src/layers/temporalconv.jl @@ -1,3 +1,12 @@ +####################### TGCN ###################################### + +function tgcn_conv(l, h, g::GNNGraph, x::AbstractArray) + x̃ = l.conv(g, x) + h, x̃ = l.gru(h, x̃) + return h, x̃ +end + + function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray) h = a3tgcn.tgcn(g, x) e = a3tgcn.dense1(h) diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index 44688cea4..c8c74b907 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -35,9 +35,7 @@ function TGCNCell(ch::Pair{Int, Int}; end function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray) - x̃ = tgcn.conv(g, x) - h, x̃ = tgcn.gru(h, x̃) - return h, x̃ + return GNNlib.tgcn_conv(tgcn, h, g, x) end function Base.show(io::IO, tgcn::TGCNCell)