Skip to content

Commit

Permalink
Add DCGRU temporal layer (#448)
Browse files Browse the repository at this point in the history
* Add `DCGRU` code

* Add `DCGRU` tests

* Add export

* Add docs

* Update src/layers/temporalconv.jl

---------

Co-authored-by: Carlo Lucibello <[email protected]>
  • Loading branch information
aurorarossi and CarloLucibello authored Jul 25, 2024
1 parent f6b95fc commit cafc1bc
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ export
A3TGCN,
GConvLSTM,
GConvGRU,
DCGRU,

# layers/pool
GlobalPool,
Expand Down
83 changes: 83 additions & 0 deletions src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,89 @@ Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)

struct DCGRUCell
in::Int
out::Int
state0
k::Int
dconv_u::DConv
dconv_r::DConv
dconv_c::DConv
end

Flux.@functor DCGRUCell

function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
in, out = ch
dconv_u = DConv((in + out) => out, k; bias=bias, init=init)
dconv_r = DConv((in + out) => out, k; bias=bias, init=init)
dconv_c = DConv((in + out) => out, k; bias=bias, init=init)
state0 = init_state(out, n)
return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c)
end

function (dcgru::DCGRUCell)(h, g::GNNGraph, x)
= vcat(x, h)
z = dcgru.dconv_u(g, h̃)
z = NNlib.sigmoid_fast.(z)
r = dcgru.dconv_r(g, h̃)
r = NNlib.sigmoid_fast.(r)
= vcat(x, h .* r)
c = dcgru.dconv_c(g, ĥ)
c = tanh.(c)
h = z.* h + (1 .- z) .* c
return h, h
end

function Base.show(io::IO, dcgru::DCGRUCell)
print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))")
end

"""
DCGRU(in => out, k, n; [bias, init, init_state])
Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural
Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).
Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
# Arguments
- `in`: Number of input features.
- `out`: Number of output features.
- `k`: Diffusion step.
- `n`: Number of nodes in the graph.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`.
# Examples
```jldoctest
julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes);
julia> y = dcgru(g1, x1);
julia> size(y)
(5, 5)
julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
julia> z = dcgru(g2, x2);
julia> size(z)
(5, 5, 30)
```
"""
DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)

(l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)

function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
return l.(tg.snapshots, x)
end
Expand Down
8 changes: 8 additions & 0 deletions test/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ end
@test model(g1) isa GNNGraph
end

@testset "DCGRU" begin
dcgru = DCGRU(in_channel => out_channel, 2, g1.num_nodes)
@test size(Flux.gradient(x -> sum(dcgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
model = GNNChain(DCGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1))
@test size(model(g1, g1.ndata.x)) == (1, N)
@test model(g1) isa GNNGraph
end

@testset "GINConv" begin
ginconv = GINConv(Dense(in_channel => out_channel),0.3)
@test length(ginconv(tg, tg.ndata.x)) == S
Expand Down

0 comments on commit cafc1bc

Please sign in to comment.