Skip to content

Commit

Permalink
GNNRecurrence
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 16, 2024
1 parent 02919ac commit cfd9ec3
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 195 deletions.
5 changes: 3 additions & 2 deletions GraphNeuralNetworks/src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ include("layers/heteroconv.jl")
export HeteroGraphConv

include("layers/temporalconv.jl")
export GConvGRU, GConvGRUCell,
export GNNRecurrence,
GConvGRU, GConvGRUCell,
GConvLSTM, GConvLSTMCell,
DCGRU, DCGRUCell,
TGCN,
A3TGCN,
DCGRU,
EvolveGCNO

include("layers/pool.jl")
Expand Down
281 changes: 217 additions & 64 deletions GraphNeuralNetworks/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,73 @@ function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
return stack(y, dims = 2)
end


"""
GNNRecurrence(cell)
Construct a recurrent layer that applies the `cell`
to process an entire temporal sequence of node features at once.
# Forward
layer(g::GNNGraph, x, [state])
- `g`: The input graph.
- `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`.
- `state`: The initial state of the cell.
If not provided, it is generated by calling `Flux.initialstates(cell)`.
Applies the recurrent cell to each timestep of the input sequence and returns the output as
an array of size `out_features x timesteps x num_nodes`.
# Examples
```jldoctest
julia> num_nodes, num_edges = 5, 10;
julia> d_in, d_out = 2, 3;
julia> timesteps = 5;
julia> g = rand_graph(num_nodes, num_edges);
julia> x = rand(Float32, d_in, timesteps, num_nodes);
julia> cell = GConvLSTMCell(d_in => d_out, 2)
GConvLSTMCell(2 => 3, 2) # 168 parameters
julia> layer = GNNRecurrence(cell)
GNNRecurrence(
GConvLSTMCell(2 => 3, 2), # 168 parameters
) # Total: 24 arrays, 168 parameters, 2.023 KiB.
julia> y = layer(g, x);
julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
# Generic recurrence wrapper: holds any GNN recurrent cell and, on the forward
# pass, applies it to every timestep of a temporal node-feature sequence.
# `G` is the concrete cell type (e.g. `GConvGRUCell`, `GConvLSTMCell`, `DCGRUCell`).
struct GNNRecurrence{G} <: GNNLayer
    cell::G  # the recurrent cell applied at each timestep
end

# Register with Flux so parameters inside the wrapped cell are collected.
Flux.@layer GNNRecurrence

# The wrapper has no state of its own; delegate initial-state creation to the cell.
Flux.initialstates(rnn::GNNRecurrence) = Flux.initialstates(rnn.cell)

# Forward without an explicit state: start from the cell's initial state.
(rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}) where {T} =
    rnn(g, x, initialstates(rnn))

# Forward with an explicit initial state: `scan` folds the cell over the
# time dimension (dim 2) of `x` and stacks the per-step outputs.
(rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}, state) where {T} =
    scan(rnn.cell, g, x, state)

# Compact one-line display, e.g. `GNNRecurrence(GConvLSTMCell(2 => 3, 2))`.
function Base.show(io::IO, rnn::GNNRecurrence)
    print(io, "GNNRecurrence(", rnn.cell, ")")
end


"""
GConvGRUCell(in => out, k; [bias, init])
Expand Down Expand Up @@ -126,24 +193,13 @@ function Base.show(io::IO, cell::GConvGRUCell)
end

"""
GConvGRU(in => out, k; kws...)
The recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell,
used to process an entire temporal sequence of node features at once.
GConvGRU(args...; kws...)
The arguments are the same as for [`GConvGRUCell`](@ref).
Construct a recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell.
It can be used to process an entire temporal sequence of node features at once.
# Forward
layer(g::GNNGraph, x, [h])
- `g`: The input graph.
- `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`.
- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
If not provided, it is assumed to be a matrix of zeros.
Applies the recurrent cell to each timestep of the input sequence and returns the output as
an array of size `out x timesteps x num_nodes`.
The arguments are passed to the [`GConvGRUCell`](@ref) constructor.
See [`GNNRecurrence`](@ref) for more details.
# Examples
Expand All @@ -158,33 +214,18 @@ julia> g = rand_graph(num_nodes, num_edges);
julia> x = rand(Float32, d_in, timesteps, num_nodes);
julia> layer = GConvGRU(d_in => d_out, 2);
julia> layer = GConvGRU(d_in => d_out, 2)
GConvGRU(
GConvGRUCell(2 => 3, 2), # 108 parameters
) # Total: 12 arrays, 108 parameters, 1.148 KiB.
julia> y = layer(g, x);
julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
struct GConvGRU{G<:GConvGRUCell} <: GNNLayer
cell::G
end

Flux.@layer GConvGRU

function GConvGRU(ch::Pair{Int,Int}, k::Int; kws...)
return GConvGRU(GConvGRUCell(ch, k; kws...))
end

Flux.initialstates(rnn::GConvGRU) = Flux.initialstates(rnn.cell)

function (rnn::GConvGRU)(g::GNNGraph, x::AbstractArray)
return scan(rnn.cell, g, x, initialstates(rnn))
end

function Base.show(io::IO, rnn::GConvGRU)
print(io, "GConvGRU($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))")
end
GConvGRU(args...; kws...) = GNNRecurrence(GConvGRUCell(args...; kws...))


"""
Expand Down Expand Up @@ -268,7 +309,7 @@ julia> size(y) # (d_out, num_nodes)
out::Int
end

Flux.@layer GConvLSTMCell
Flux.@layer :noexpand GConvLSTMCell

function GConvLSTMCell(ch::Pair{Int, Int}, k::Int;
bias::Bool = true,
Expand Down Expand Up @@ -305,6 +346,8 @@ function Flux.initialstates(cell::GConvLSTMCell)
(zeros_like(cell.conv_x_i.weight, cell.out), zeros_like(cell.conv_x_i.weight, cell.out))
end

(cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))

function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c))
if h isa AbstractVector
h = repeat(h, 1, g.num_nodes)
Expand Down Expand Up @@ -334,29 +377,74 @@ end


"""
GConvLSTM(in => out, k; kws...)
GConvLSTM(args...; kws...)
Construct a recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell.
It can be used to process an entire temporal sequence of node features at once.
The arguments are passed to the [`GConvLSTMCell`](@ref) constructor.
See [`GNNRecurrence`](@ref) for more details.
# Examples
```jldoctest
julia> num_nodes, num_edges = 5, 10;
julia> d_in, d_out = 2, 3;
julia> timesteps = 5;
julia> g = rand_graph(num_nodes, num_edges);
julia> x = rand(Float32, d_in, timesteps, num_nodes);
julia> layer = GConvLSTM(d_in => d_out, 2)
GNNRecurrence(
GConvLSTMCell(2 => 3, 2), # 168 parameters
) # Total: 24 arrays, 168 parameters, 2.023 KiB.
julia> y = layer(g, x);
julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
GConvLSTM(args...; kws...) = GNNRecurrence(GConvLSTMCell(args...; kws...))

"""
DCGRUCell(in => out, k; [bias, init])
Diffusion Convolutional Recurrent Neural Network (DCGRU) cell from the paper
[Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/abs/1707.01926).
Applies a [`DConv`](@ref) layer to model spatial dependencies,
in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
The recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell,
used to process an entire temporal sequence of node features at once.
# Arguments
The arguments are the same as for [`GConvLSTMCell`](@ref).
- `in`: Number of input node features.
- `out`: Number of output node features.
- `k`: Diffusion step for the `DConv`.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
# Forward
layer(g::GNNGraph, x, [state])
cell(g::GNNGraph, x, [h])
- `g`: The input graph.
- `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`.
- `state`: The initial hidden state of the LSTM cell.
If given, it is a tuple `(h, c)` where both elements are matrices of size `out x num_nodes`.
If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros.
- `x`: The node features. It should be a matrix of size `in x num_nodes`.
- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
If not provided, it is assumed to be a matrix of zeros.
Applies the recurrent cell to each timestep of the input sequence and returns the output as
an array of size `out x timesteps x num_nodes`.
Performs one recurrence step and returns a tuple `(h, h)`,
where `h` is the updated hidden state of the GRU cell.
# Examples
```jldoctest
julia> using GraphNeuralNetworks, Flux
julia> num_nodes, num_edges = 5, 10;
julia> d_in, d_out = 2, 3;
Expand All @@ -365,33 +453,98 @@ julia> timesteps = 5;
julia> g = rand_graph(num_nodes, num_edges);
julia> x = rand(Float32, d_in, timesteps, num_nodes);
julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps];
julia> layer = GConvLSTM(d_in => d_out, 2);
julia> cell = DCGRUCell(d_in => d_out, 2);
julia> y = layer(g, x);
julia> state = Flux.initialstates(cell);
julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
julia> y = state;
julia> for xt in x
y, state = cell(g, xt, state)
end
julia> size(y) # (d_out, num_nodes)
(3, 5)
```
"""
struct GConvLSTM{G<:GConvLSTMCell} <: GNNLayer
cell::G
"""
# Diffusion-Convolutional GRU cell (Li et al., DCRNN): a GRU whose dense
# layers are replaced by diffusion convolutions over the graph.
struct DCGRUCell
    in::Int          # number of input node features
    out::Int         # number of output node features / hidden size
    k::Int           # diffusion steps for each `DConv`
    dconv_u::DConv   # update-gate convolution
    dconv_r::DConv   # reset-gate convolution
    dconv_c::DConv   # candidate-state convolution
end

Flux.@layer GConvLSTM
Flux.@layer :noexpand DCGRUCell

function GConvLSTM(ch::Pair{Int,Int}, k::Int; kws...)
return GConvLSTM(GConvLSTMCell(ch, k; kws...))
# Build a DCGRU cell with three independently-initialized gate convolutions.
function DCGRUCell(ch::Pair{Int,Int}, k::Int; bias = true, init = glorot_uniform)
    in, out = ch
    # Each gate convolves the concatenation [x; h], hence `in + out` input channels.
    newconv() = DConv((in + out) => out, k; bias, init)
    return DCGRUCell(in, out, k, newconv(), newconv(), newconv())
end

Flux.initialstates(rnn::GConvLSTM) = Flux.initialstates(rnn.cell)
Flux.initialstates(cell::DCGRUCell) = zeros_like(cell.dconv_u.weights, cell.out)

(cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))

function (rnn::GConvLSTM)(g::GNNGraph, x::AbstractArray)
return scan(rnn.cell, g, x, initialstates(rnn))
# A vector state is shared across nodes: expand it to one column per node,
# then dispatch to the matrix-state method.
(cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) =
    cell(g, x, repeat(h, 1, g.num_nodes))

function Base.show(io::IO, rnn::GConvLSTM)
print(io, "GConvLSTM($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))")
# One DCGRU recurrence step. GRU-style gating where each dense layer is
# replaced by a diffusion convolution over the graph.
# NOTE(review): the assignment targets `h̃` and `ĥ` were dropped in the
# rendered diff (lines read `= vcat(...)`); reconstructed here from their
# uses in `cell.dconv_u(g, h̃)` and `cell.dconv_c(g, ĥ)`.
function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
    h̃ = vcat(x, h)                     # gate input: concatenated features and state
    z = cell.dconv_u(g, h̃)
    z = NNlib.sigmoid_fast.(z)          # update gate
    r = cell.dconv_r(g, h̃)
    r = NNlib.sigmoid_fast.(r)          # reset gate
    ĥ = vcat(x, h .* r)                # candidate input uses the reset state
    c = cell.dconv_c(g, ĥ)
    c = NNlib.tanh_fast.(c)             # candidate state
    h = z .* h .+ (1 .- z) .* c         # convex combination of old state and candidate
    return h, h                         # (output, new state) — Flux cell convention
end

# Compact one-line display, e.g. `DCGRUCell(2 => 3, 2)`.
function Base.show(io::IO, cell::DCGRUCell)
    print(io, "DCGRUCell(", cell.in, " => ", cell.out, ", ", cell.k, ")")
end

"""
DCGRU(args...; kws...)
Construct a recurrent layer corresponding to the [`DCGRUCell`](@ref) cell.
It can be used to process an entire temporal sequence of node features at once.
The arguments are passed to the [`DCGRUCell`](@ref) constructor.
See [`GNNRecurrence`](@ref) for more details.
# Examples
```jldoctest
julia> num_nodes, num_edges = 5, 10;
julia> d_in, d_out = 2, 3;
julia> timesteps = 5;
julia> g = rand_graph(num_nodes, num_edges);
julia> x = rand(Float32, d_in, timesteps, num_nodes);
julia> layer = DCGRU(d_in => d_out, 2)
GNNRecurrence(
DCGRUCell(2 => 3, 2), # 189 parameters
) # Total: 6 arrays, 189 parameters, 1.184 KiB.
julia> y = layer(g, x);
julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
DCGRU(args...; kws...) = GNNRecurrence(DCGRUCell(args...; kws...))


Loading

0 comments on commit cfd9ec3

Please sign in to comment.