diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl
index 8c98fc474..e0a199276 100644
--- a/GNNLux/src/GNNLux.jl
+++ b/GNNLux/src/GNNLux.jl
@@ -32,7 +32,7 @@ export AGNNConv,
        # GMMConv,
        GraphConv,
        MEGNetConv,
-       # NNConv,
+       NNConv,
        # ResGatedGraphConv,
        # SAGEConv,
        SGConv
@@ -44,4 +44,4 @@ export TGCN
 export A3TGCN
 
 end #module
- 
\ No newline at end of file
+ 
diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl
index 30564ae48..cfe8157df 100644
--- a/GNNLux/src/layers/conv.jl
+++ b/GNNLux/src/layers/conv.jl
@@ -629,6 +629,44 @@ function Base.show(io::IO, l::GINConv)
     print(io, ")")
 end
 
+@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
+    nn <: AbstractExplicitLayer
+    aggr
+    in_dims::Int
+    out_dims::Int
+    use_bias::Bool
+    init_weight
+    init_bias
+    σ
+end
+
+function NNConv(ch::Pair{Int, Int}, nn, σ = identity;
+                aggr = +,
+                init_bias = zeros32,
+                use_bias::Bool = true,
+                init_weight = glorot_uniform,
+                allow_fast_activation::Bool = true)
+    in_dims, out_dims = ch
+    σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
+    return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ)
+end
+
+function (l::NNConv)(g, x, edge_weight, ps, st)
+    nn = StatefulLuxLayer{true}(l.nn, ps, st)
+
+    m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), l.σ)
+    y = GNNlib.nn_conv(m, g, x, edge_weight)
+    stnew = _getstate(nn)
+    return y, stnew
+end
+
+function Base.show(io::IO, l::NNConv)
+    print(io, "NNConv($(l.nn)")
+    l.σ == identity || print(io, ", ", l.σ)
+    l.use_bias || print(io, ", use_bias=false")
+    print(io, ")")
+end
+
 @concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
     in_dims::Int
     out_dims::Int
@@ -669,4 +707,4 @@ function Base.show(io::IO, l::MEGNetConv)
     nout = l.out_dims
     print(io, "MEGNetConv(", nin, " => ", nout)
     print(io, ")")
-end
\ No newline at end of file
+end
diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl
index 86a056977..20908cb3a 100644
--- a/GNNLux/test/layers/conv_tests.jl
+++ b/GNNLux/test/layers/conv_tests.jl
@@ -1,10 +1,13 @@
 @testitem "layers/conv" setup=[SharedTestSetup] begin
     rng = StableRNG(1234)
+    edim = 10
     g = rand_graph(rng, 10, 40)
     in_dims = 3
     out_dims = 5
     x = randn(rng, Float32, in_dims, 10)
 
+    g2 = GNNGraph(g, edata = rand(Float32, edim, g.num_edges))
+
     @testset "GCNConv" begin
         l = GCNConv(in_dims => out_dims, tanh)
         test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
@@ -94,6 +97,32 @@
         test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
     end
 
+
+
+    @testset "NNConv" begin
+        n_in = 3
+        n_in_edge = 10
+        n_out = 5
+
+        s = [1,1,2,3]
+        t = [2,3,1,1]
+        g2 = GNNGraph(s, t)
+
+        nn = Dense(n_in_edge => n_out * n_in)
+        l = NNConv(n_in => n_out, nn, tanh, aggr = +)
+        x = randn(Float32, n_in, g2.num_nodes)
+        e = randn(Float32, n_in_edge, g2.num_edges)
+
+        ps = LuxCore.initialparameters(rng, l)
+        st = LuxCore.initialstates(rng, l)
+
+        y = l(g2, x, e, ps, st) # just to see if it runs without an error
+        #edim = 10
+        #nn = Dense(edim, in_dims * out_dims)
+        #l = NNConv(in_dims => out_dims, nn, tanh, aggr = +)
+        #test_lux_layer(rng, l, g2, x, sizey=(n_out, g2.num_nodes), container=true, edge_weight=e)
+    end
+
 
     @testset "MEGNetConv" begin
         l = MEGNetConv(in_dims => out_dims)
diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl
index b6b80df49..797e1577d 100644
--- a/GNNLux/test/shared_testsetup.jl
+++ b/GNNLux/test/shared_testsetup.jl
@@ -14,7 +14,7 @@ export test_lux_layer
 function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
             outputsize=nothing, sizey=nothing,
             container=false,
-            atol=1.0f-2, rtol=1.0f-2)
+            atol=1.0f-2, rtol=1.0f-2, edge_weight=nothing)
 
     if container
         @test l isa GNNContainerLayer
@@ -26,8 +26,13 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
     st = LuxCore.initialstates(rng, l)
     @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
     @test LuxCore.statelength(l) == LuxCore.statelength(st)
-    
-    y, st′ = l(g, x, ps, st)
+
+    if edge_weight !== nothing
+        y, st′ = l(g, x, edge_weight, ps, st)
+    else
+        y, st′ = l(g, x, ps, st)
+    end
+
     @test eltype(y) == eltype(x)
     if outputsize !== nothing
         @test LuxCore.outputsize(l) == outputsize
@@ -42,4 +47,4 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
     test_gradients(loss, x, ps;
                    atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
 end
-end
\ No newline at end of file
+end
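Usage sketch: the snippet below mirrors the call pattern exercised in the `NNConv` testset above, with edge features passed as the third positional argument of the new `(l::NNConv)(g, x, edge_weight, ps, st)` method. It is a minimal sketch, not part of the patch: it assumes `GNNGraph` is re-exported by GNNLux, uses `Lux.setup` in place of the separate `LuxCore.initialparameters`/`initialstates` calls from the test, and the names `n_in`, `n_in_edge`, `n_out` are illustrative.

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()

# Small graph with 3 nodes and 4 edges, matching the test above.
s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t)

n_in, n_in_edge, n_out = 3, 10, 5                # illustrative feature sizes
x = randn(rng, Float32, n_in, g.num_nodes)       # node features
e = randn(rng, Float32, n_in_edge, g.num_edges)  # edge features

# The inner network maps each edge feature vector to an n_out * n_in weight block.
nn = Dense(n_in_edge => n_out * n_in)
l = NNConv(n_in => n_out, nn, tanh, aggr = +)

ps, st = Lux.setup(rng, l)

# Forward pass: edge features go in as the third argument.
y, st = l(g, x, e, ps, st)   # expected size of y: (n_out, g.num_nodes)
```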