diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index ecac67b5a..f9d944508 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -34,7 +34,9 @@ export AGNNConv, # SGConv, # TAGConv, # TransformerConv - + +include("layers/temporalconv.jl") +export TGCN end #module \ No newline at end of file diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl new file mode 100644 index 000000000..2ab6235b3 --- /dev/null +++ b/GNNLux/src/layers/temporalconv.jl @@ -0,0 +1,59 @@ +@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)} + cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer} +end + +function LuxCore.initialstates(rng::AbstractRNG, r::GNNLux.StatefulRecurrentCell) + return (cell=LuxCore.initialstates(rng, r.cell), carry=nothing) +end + +function (r::StatefulRecurrentCell)(g, x::AbstractMatrix, ps, st::NamedTuple) + (out, carry), st = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry) + return out, (; cell=st, carry) +end + +function (r::StatefulRecurrentCell)(g, x::AbstractVector, ps, st::NamedTuple) + st, carry = st.cell, st.carry + for xᵢ in x + (out, carry), st = applyrecurrentcell(r.cell, g, xᵢ, ps, st, carry) + end + return out, (; cell=st, carry) +end + +function applyrecurrentcell(l, g, x, ps, st, carry) + return Lux.apply(l, g, (x, carry), ps, st) +end + +LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st) + +@concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)} + in_dims::Int + out_dims::Int + conv + gru + init_state::Function +end + +function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true) + in_dims, out_dims = ch + conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true) + gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state) + return TGCNCell(in_dims, out_dims, conv, gru, init_state) +end + +function (l::TGCNCell)(g, (x, h), ps, st) + if h === nothing + h = l.init_state(l.out_dims, 1) + end + x̃, stconv = l.conv(g, x, ps.conv, st.conv) + (h, (h,)), stgru = l.gru((x̃,(h,)), ps.gru,st.gru) + return (h, h), (conv=stconv, gru=stgru) +end + +LuxCore.outputsize(l::TGCNCell) = (l.out_dims,) +LuxCore.outputsize(l::GNNLux.StatefulRecurrentCell) = (l.cell.out_dims,) + +function Base.show(io::IO, tgcn::TGCNCell) + print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))") +end + +TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...)) \ No newline at end of file diff --git a/GNNLux/test/layers/temporalconv_test.jl b/GNNLux/test/layers/temporalconv_test.jl new file mode 100644 index 000000000..bdde7b325 --- /dev/null +++ b/GNNLux/test/layers/temporalconv_test.jl @@ -0,0 +1,15 @@ +@testitem "layers/temporalconv" setup=[SharedTestSetup] begin + using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme + + rng = StableRNG(1234) + g = rand_graph(10, 40, seed=1234) + x = randn(rng, Float32, 3, 10) + + @testset "TGCN" begin + l = TGCN(3=>3) + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) + end +end \ No newline at end of file diff --git a/GNNlib/src/GNNlib.jl b/GNNlib/src/GNNlib.jl index a84253776..3ca4acc9b 100644 --- a/GNNlib/src/GNNlib.jl +++ b/GNNlib/src/GNNlib.jl @@ -61,7 +61,7 @@ export agnn_conv, transformer_conv include("layers/temporalconv.jl") -export a3tgcn_conv +export tgcn_conv include("layers/pool.jl") export global_pool,