From cafc1bcc25305c5d44a6855e67cab7666df005bb Mon Sep 17 00:00:00 2001 From: Aurora Rossi <65721467+aurorarossi@users.noreply.github.com> Date: Thu, 25 Jul 2024 19:14:00 +0200 Subject: [PATCH] Add `DCGRU` temporal layer (#448) * Add `DCGRU` code * Add `DCGRU` tests * Add export * Add docs * Update src/layers/temporalconv.jl --------- Co-authored-by: Carlo Lucibello --- src/GraphNeuralNetworks.jl | 1 + src/layers/temporalconv.jl | 83 +++++++++++++++++++++++++++++++++++++ test/layers/temporalconv.jl | 8 ++++ 3 files changed, 92 insertions(+) diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 030fd46df..0471e6555 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -75,6 +75,7 @@ export A3TGCN, GConvLSTM, GConvGRU, + DCGRU, # layers/pool GlobalPool, diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index 23df990aa..44688cea4 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -401,6 +401,89 @@ Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0) _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x) _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g) +struct DCGRUCell + in::Int + out::Int + state0 + k::Int + dconv_u::DConv + dconv_r::DConv + dconv_c::DConv +end + +Flux.@functor DCGRUCell + +function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) + in, out = ch + dconv_u = DConv((in + out) => out, k; bias=bias, init=init) + dconv_r = DConv((in + out) => out, k; bias=bias, init=init) + dconv_c = DConv((in + out) => out, k; bias=bias, init=init) + state0 = init_state(out, n) + return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c) +end + +function (dcgru::DCGRUCell)(h, g::GNNGraph, x) + h̃ = vcat(x, h) + z = dcgru.dconv_u(g, h̃) + z = NNlib.sigmoid_fast.(z) + r = dcgru.dconv_r(g, h̃) + r = NNlib.sigmoid_fast.(r) + ĥ = vcat(x, h .* r) + c = dcgru.dconv_c(g, ĥ) + c = tanh.(c) + h = z.* h + (1 .- z) .* c + return h, h +end + +function Base.show(io::IO, dcgru::DCGRUCell) + print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))") +end + +""" + DCGRU(in => out, k, n; [bias, init, init_state]) + +Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural +Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). + +Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. + +# Arguments + +- `in`: Number of input features. +- `out`: Number of output features. +- `k`: Diffusion step. +- `n`: Number of nodes in the graph. +- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. +- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. + +# Examples + +```jldoctest +julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); + +julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes); + +julia> y = dcgru(g1, x1); + +julia> size(y) +(5, 5) + +julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); + +julia> z = dcgru(g2, x2); + +julia> size(z) +(5, 5, 30) +``` +""" +DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...)) +Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0) + +(l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) +_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x) +_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g) + function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) return l.(tg.snapshots, x) end diff --git a/test/layers/temporalconv.jl b/test/layers/temporalconv.jl index 2bb7859f6..b55aff808 100644 --- a/test/layers/temporalconv.jl +++ b/test/layers/temporalconv.jl @@ -61,6 +61,14 @@ end @test model(g1) isa GNNGraph end +@testset "DCGRU" begin + dcgru = DCGRU(in_channel => out_channel, 2, g1.num_nodes) + @test size(Flux.gradient(x -> sum(dcgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N) + model = GNNChain(DCGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) + @test size(model(g1, g1.ndata.x)) == (1, N) + @test model(g1) isa GNNGraph +end + @testset "GINConv" begin ginconv = GINConv(Dense(in_channel => out_channel),0.3) @test length(ginconv(tg, tg.ndata.x)) == S