diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl
index e4d1c09aa..d8970095c 100644
--- a/GNNLux/src/GNNLux.jl
+++ b/GNNLux/src/GNNLux.jl
@@ -1,8 +1,11 @@
 module GNNLux
 using ConcreteStructs: @concrete
 using NNlib: NNlib, sigmoid, relu, swish
-using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
-using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
+using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
+               initialparameters, initialstates
+using Lux: Lux, Chain, Dense, GRUCell,
+           glorot_uniform, zeros32,
+           StatefulLuxLayer
 using Reexport: @reexport
 using Random: AbstractRNG
 using GNNlib: GNNlib
@@ -22,9 +25,9 @@ export AGNNConv,
        DConv,
        GATConv,
        GATv2Conv,
-       # GatedGraphConv,
+       GatedGraphConv,
        GCNConv,
-       # GINConv,
+       GINConv,
        # GMMConv,
        GraphConv,
        # MEGNetConv,
diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl
index 672bbe20a..83c3efddc 100644
--- a/GNNLux/src/layers/conv.jl
+++ b/GNNLux/src/layers/conv.jl
@@ -38,7 +38,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
 end
 
 LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
-LuxCore.statelength(d::GCNConv) = 0
 LuxCore.outputsize(d::GCNConv) = (d.out_dims,)
 
 function Base.show(io::IO, l::GCNConv)
@@ -549,7 +548,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv)
 end
 
 LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
-LuxCore.statelength(d::SGConv) = 0
 LuxCore.outputsize(d::SGConv) = (d.out_dims,)
 
 function Base.show(io::IO, l::SGConv)
@@ -561,14 +559,72 @@ function Base.show(io::IO, l::SGConv)
     print(io, ")")
 end
 
-(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) =
-    l(g, x, edge_weight, ps, st; conv_weight)
-
-function (l::SGConv)(g, x, edge_weight, ps, st;
-                     conv_weight=nothing, )
+(l::SGConv)(g, x, ps, st) = l(g, x, nothing, ps, st)
+
+function (l::SGConv)(g, x, edge_weight, ps, st)
     m = (; ps.weight, bias = _getbias(ps),
            l.add_self_loops, l.use_edge_weight, l.k)
     y = GNNlib.sg_conv(m, g, x, edge_weight)
     return y, st
-end
\ No newline at end of file
+end
+
+@concrete struct GatedGraphConv <: GNNLayer
+    gru
+    init_weight
+    dims::Int
+    num_layers::Int
+    aggr
+end
+
+
+function GatedGraphConv(dims::Int, num_layers::Int;
+                        aggr = +, init_weight = glorot_uniform)
+    gru = GRUCell(dims => dims)
+    return GatedGraphConv(gru, init_weight, dims, num_layers, aggr)
+end
+
+LuxCore.outputsize(l::GatedGraphConv) = (l.dims,)
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv)
+    gru = LuxCore.initialparameters(rng, l.gru)
+    weight = l.init_weight(rng, l.dims, l.dims, l.num_layers)
+    return (; gru, weight)
+end
+
+LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2 * l.num_layers
+
+
+function (l::GatedGraphConv)(g, x, ps, st)
+    gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
+    fgru = (h, x) -> gru((x, (h,)))  # make the forward compatible with Flux.GRUCell style
+    m = (; gru = fgru, ps.weight, l.num_layers, l.aggr, l.dims)
+    return GNNlib.gated_graph_conv(m, g, x), st
+end
+
+function Base.show(io::IO, l::GatedGraphConv)
+    print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
+    print(io, ", aggr=", l.aggr)
+    print(io, ")")
+end
+
+@concrete struct GINConv <: GNNContainerLayer{(:nn,)}
+    nn <: AbstractExplicitLayer
+    ϵ <: Real
+    aggr
+end
+
+GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)
+
+function (l::GINConv)(g, x, ps, st)
+    nn = StatefulLuxLayer{true}(l.nn, ps, st)
+    m = (; nn, l.ϵ, l.aggr)
+    y = GNNlib.gin_conv(m, g, x)
+    stnew = _getstate(nn)
+    return y, stnew
+end
+
+function Base.show(io::IO, l::GINConv)
+    print(io, "GINConv($(l.nn)")
+    print(io, ", $(l.ϵ)")
+    print(io, ")")
+end
diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl
index 2f18103e1..9f010f39e 100644
--- a/GNNLux/test/layers/conv_tests.jl
+++ b/GNNLux/test/layers/conv_tests.jl
@@ -82,5 +82,15 @@
         l = SGConv(in_dims => out_dims, 2)
         test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
     end
-end
+
+    @testset "GatedGraphConv" begin
+        l = GatedGraphConv(in_dims, 3)
+        test_lux_layer(rng, l, g, x, outputsize=(in_dims,))
+    end
+
+    @testset "GINConv" begin
+        nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
+        l = GINConv(nn, 0.5)
+        test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true)
+    end
+end
diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl
index 1354ef387..b6b80df49 100644
--- a/GNNLux/test/shared_testsetup.jl
+++ b/GNNLux/test/shared_testsetup.jl
@@ -28,6 +28,7 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
     @test LuxCore.statelength(l) == LuxCore.statelength(st)
 
     y, st′ = l(g, x, ps, st)
+    @test eltype(y) == eltype(x)
     if outputsize !== nothing
         @test LuxCore.outputsize(l) == outputsize
     end
diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl
index 91a80ff0f..50b5b34aa 100644
--- a/GNNlib/src/layers/conv.jl
+++ b/GNNlib/src/layers/conv.jl
@@ -28,7 +28,7 @@ function gcn_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, norm_fn::F, conv_w
         if edge_weight !== nothing
             # Pad weights with ones
             # TODO for ADJMAT_T the new edges are not generally at the end
-            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
             @assert length(edge_weight) == g.num_edges
         end
     end
@@ -215,23 +215,22 @@ end
 
 ####################### GatedGraphConv ######################################
 
-# TODO PIRACY! remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
-@non_differentiable fill!(x...)
-
-function gated_graph_conv(l, g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real}
-    check_num_nodes(g, H)
-    m, n = size(H)
-    @assert (m<=l.out_ch) "number of input features must less or equals to output features."
-    if m < l.out_ch
-        Hpad = similar(H, S, l.out_ch - m, n)
-        H = vcat(H, fill!(Hpad, 0))
+function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    m, n = size(x)
+    @assert m <= l.dims "number of input features must be less than or equal to output features."
+    if m < l.dims
+        xpad = zeros_like(x, (l.dims - m, n))
+        x = vcat(x, xpad)
     end
+    h = x
     for i in 1:(l.num_layers)
-        M = view(l.weight, :, :, i) * H
-        M = propagate(copy_xj, g, l.aggr; xj = M)
-        H, _ = l.gru(H, M)
+        m = view(l.weight, :, :, i) * h
+        m = propagate(copy_xj, g, l.aggr; xj = m)
+        # in gru forward, hidden state is first argument, input is second
+        h, _ = l.gru(h, m)
     end
-    return H
+    return h
 end
 
 ####################### EdgeConv ######################################
@@ -419,7 +418,7 @@ function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T},
     if l.add_self_loops
         g = add_self_loops(g)
         if edge_weight !== nothing
-            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
             @assert length(edge_weight) == g.num_edges
         end
     end
@@ -512,7 +511,7 @@ function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T},
     if l.add_self_loops
         g = add_self_loops(g)
         if edge_weight !== nothing
-            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
             @assert length(edge_weight) == g.num_edges
         end
     end
@@ -644,7 +643,7 @@ function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T},
     if l.add_self_loops
         g = add_self_loops(g)
         if edge_weight !== nothing
-            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
             @assert length(edge_weight) == g.num_edges
         end
     end
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index ec9268bd0..4a9f31783 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -486,7 +486,7 @@ where ``\mathbf{h}^{(l)}_i`` denotes the ``l``-th hidden variables passing throu
 
 # Arguments
 
 - `out`: The dimension of output features.
-- `num_layers`: The number of gated recurrent unit.
+- `num_layers`: The number of recursion steps.
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 - `init`: Weight initialization function.
@@ -510,25 +510,25 @@ y = l(g, x)
 struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer
     weight::W
     gru::R
-    out_ch::Int
+    dims::Int
     num_layers::Int
     aggr::A
 end
 
 @functor GatedGraphConv
 
-function GatedGraphConv(out_ch::Int, num_layers::Int;
+function GatedGraphConv(dims::Int, num_layers::Int;
                         aggr = +, init = glorot_uniform)
-    w = init(out_ch, out_ch, num_layers)
-    gru = GRUCell(out_ch, out_ch)
-    GatedGraphConv(w, gru, out_ch, num_layers, aggr)
+    w = init(dims, dims, num_layers)
+    gru = GRUCell(dims => dims)
+    GatedGraphConv(w, gru, dims, num_layers, aggr)
 end
 
 (l::GatedGraphConv)(g, H) = GNNlib.gated_graph_conv(l, g, H)
 
 function Base.show(io::IO, l::GatedGraphConv)
-    print(io, "GatedGraphConv(($(l.out_ch) => $(l.out_ch))^$(l.num_layers)")
+    print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
     print(io, ", aggr=", l.aggr)
     print(io, ")")
 end
@@ -1201,7 +1201,7 @@ function SGConv(ch::Pair{Int, Int}, k = 1;
     in, out = ch
     W = init(out, in)
     b = bias ? Flux.create_bias(W, true, out) : false
-    SGConv(W, b, k, add_self_loops, use_edge_weight)
+    return SGConv(W, b, k, add_self_loops, use_edge_weight)
 end
 
 (l::SGConv)(g, x, edge_weight = nothing) = GNNlib.sg_conv(l, g, x, edge_weight)
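
Usage sketch (not part of the patch): the snippet below shows how the two newly exported GNNLux layers are called under the explicit `ps`/`st` convention exercised by `test_lux_layer` above. It assumes `rand_graph` is available from GNNGraphs; the graph size, feature dimensions, and variable names are illustrative placeholders.

```julia
using GNNLux, Lux, LuxCore, Random
using GNNGraphs: rand_graph

rng = Random.default_rng()
g = rand_graph(10, 40)               # toy graph: 10 nodes, 40 edges
x = rand(Float32, 3, g.num_nodes)    # 3 input features per node

# GatedGraphConv zero-pads the 3 input features up to dims = 8,
# then runs num_layers = 2 rounds of propagate + GRU update.
l = GatedGraphConv(8, 2)
ps, st = LuxCore.setup(rng, l)
y, st = l(g, x, ps, st)              # size(y) == (8, g.num_nodes)

# GINConv wraps an arbitrary Lux network applied after aggregation.
nn = Chain(Dense(3 => 16, relu), Dense(16 => 16))
lgin = GINConv(nn, 0.5f0)
ps, st = LuxCore.setup(rng, lgin)
y, st = lgin(g, x, ps, st)           # size(y) == (16, g.num_nodes)
```

Both forwards return `(y, st)`, so the layers compose with the rest of the Lux ecosystem; the GRU state of `GatedGraphConv` lives under `st.gru`, while `GINConv` threads the wrapped network's state through its return value.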