diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index ee868b6b6..e4d1c09aa 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -26,12 +26,12 @@ export AGNNConv, GCNConv, # GINConv, # GMMConv, - GraphConv + GraphConv, # MEGNetConv, # NNConv, # ResGatedGraphConv, # SAGEConv, - # SGConv, + SGConv # TAGConv, # TransformerConv diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 15b1bbf4b..672bbe20a 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -515,4 +515,60 @@ function Base.show(io::IO, l::GATv2Conv) l.σ == identity || print(io, ", ", l.σ) print(io, ", negative_slope=", l.negative_slope) print(io, ")") +end + +@concrete struct SGConv <: GNNLayer + in_dims::Int + out_dims::Int + k::Int + use_bias::Bool + add_self_loops::Bool + use_edge_weight::Bool + init_weight + init_bias +end + +function SGConv(ch::Pair{Int, Int}, k = 1; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true, + add_self_loops::Bool = true, + use_edge_weight::Bool = false) + in_dims, out_dims = ch + return SGConv(in_dims, out_dims, k, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv) + weight = l.init_weight(rng, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight, bias) + else + return (; weight) + end +end + +LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims +LuxCore.statelength(d::SGConv) = 0 +LuxCore.outputsize(d::SGConv) = (d.out_dims,) + +function Base.show(io::IO, l::SGConv) + print(io, "SGConv(", l.in_dims, " => ", l.out_dims) + l.k || print(io, ", ", l.k) + l.use_bias || print(io, ", use_bias=false") + l.add_self_loops || print(io, ", add_self_loops=false") + !l.use_edge_weight || print(io, ", use_edge_weight=true") + print(io, ")") +end + +(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) = + l(g, x, edge_weight, ps, st; conv_weight) + +function (l::SGConv)(g, x, edge_weight, ps, st; + conv_weight=nothing, ) + + m = (; ps.weight, bias = _getbias(ps), + l.add_self_loops, l.use_edge_weight, l.k) + y = GNNlib.sg_conv(m, g, x, edge_weight) + return y, st end \ No newline at end of file diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index b2e81173d..2f18103e1 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -77,5 +77,10 @@ #TODO test edge end + + @testset "SGConv" begin + l = SGConv(in_dims => out_dims, 2) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 2fb5bc44f..91a80ff0f 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -722,4 +722,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix) T1_out = T2_out end return h .+ l.bias -end +end \ No newline at end of file