From b55d31de8fd55b1c05ae5bfd7d37ad47c408c7bb Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 1 Aug 2024 18:13:44 +0530 Subject: [PATCH 1/5] added sgconv lux --- GNNLux/src/GNNLux.jl | 4 +-- GNNLux/src/layers/conv.jl | 56 ++++++++++++++++++++++++++++++++ GNNLux/test/layers/conv_tests.jl | 5 +++ GNNlib/src/layers/conv.jl | 54 ++++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 2 deletions(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index ee868b6b6..e4d1c09aa 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -26,12 +26,12 @@ export AGNNConv, GCNConv, # GINConv, # GMMConv, - GraphConv + GraphConv, # MEGNetConv, # NNConv, # ResGatedGraphConv, # SAGEConv, - # SGConv, + SGConv # TAGConv, # TransformerConv diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 15b1bbf4b..26af1cf7e 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -515,4 +515,60 @@ function Base.show(io::IO, l::GATv2Conv) l.σ == identity || print(io, ", ", l.σ) print(io, ", negative_slope=", l.negative_slope) print(io, ")") +end + +@concrete struct SGConv <: GNNLayer + in_dims::Int + out_dims::Int + k::Int + use_bias::Bool + add_self_loops::Bool + use_edge_weight::Bool + init_weight + init_bias +end + +function SGConv(ch::Pair{Int, Int}, k = 1; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true, + add_self_loops::Bool = true, + use_edge_weight::Bool = false) + in_dims, out_dims = ch + return SGConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, k) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv) + weight = l.init_weight(rng, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight, bias) + else + return (; weight) + end +end + +LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims +LuxCore.statelength(d::SGConv) = 0 +LuxCore.outputsize(d::SGConv) = (d.out_dims,) + +function Base.show(io::IO, l::SGConv) + print(io, "SGConv(", l.in_dims, " => ", l.out_dims) + l.k || print(io, ", ", l.k) + l.use_bias || print(io, ", use_bias=false") + l.add_self_loops || print(io, ", add_self_loops=false") + !l.use_edge_weight || print(io, ", use_edge_weight=true") + print(io, ")") +end + +(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) = + l(g, x, edge_weight, ps, st; conv_weight) + +function (l::SGConv)(g, x, edge_weight, ps, st; + conv_weight=nothing, ) + + m = (; ps.weight, bias = _getbias(ps), + l.add_self_loops, l.use_edge_weight, l.σ) + y = GNNlib.sg_conv(m, g, x, edge_weight, conv_weight) + return y, st end \ No newline at end of file diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index b2e81173d..a4e2fdf79 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -77,5 +77,10 @@ #TODO test edge end + + @testset "SGConv" begin + l = SGConv(in_dims => out_dims, relu) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 2fb5bc44f..3e97fae75 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -723,3 +723,57 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix) end return h .+ l.bias end + +####################### GCNConv ###################################### + +function sg_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, conv_weight::CW) where + {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} + if edge_weight !== nothing + @assert length(edge_weight) == g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" + end + + if conv_weight === nothing + weight = l.weight + else + weight = conv_weight + if size(weight) != size(l.weight) + throw(ArgumentError("The weight matrix has the wrong size. Expected $(size(l.weight)) but got $(size(weight))")) + end + end + + if l.add_self_loops + g = add_self_loops(g) + if edge_weight !== nothing + edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)] + @assert length(edge_weight) == g.num_edges + end + end + Dout, Din = size(l.weight) + if Dout < Din + x = l.weight * x + end + d = degree(g, T; dir=:in, edge_weight) + c = 1 ./ sqrt.(d) + for iter in 1:l.k + x = x .* c' + if edge_weight !== nothing + x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight) + elseif l.use_edge_weight + x = propagate(w_mul_xj, g, +, xj=x) + else + x = propagate(copy_xj, g, +, xj=x) + end + x = x .* c' + end + if Dout >= Din + x = l.weight * x + end + return (x .+ l.bias) +end + +# when we also have edge_weight we need to convert the graph to COO +function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, conv_weight::CW) where + {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} + g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO + return gcn_conv(l, g, x, edge_weight, conv_weight) +end \ No newline at end of file From 96c4092bf2011331947a502c42df959225344240 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 1 Aug 2024 18:17:40 +0530 Subject: [PATCH 2/5] fix --- GNNlib/src/layers/conv.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 3e97fae75..0f345881e 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -724,7 +724,7 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix) return h .+ l.bias end -####################### GCNConv ###################################### +####################### SGConv ###################################### function sg_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, conv_weight::CW) where {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} @@ -772,8 +772,8 @@ function sg_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, conv_weight::CW) wh end # when we also have edge_weight we need to convert the graph to COO -function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, conv_weight::CW) where +function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, conv_weight::CW) where {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO - return gcn_conv(l, g, x, edge_weight, conv_weight) + return sg_conv(l, g, x, edge_weight, conv_weight) end \ No newline at end of file From de9a39a34669fbca3a7f1c4bf04260f8b466c107 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 1 Aug 2024 19:50:23 +0530 Subject: [PATCH 3/5] fix --- GNNLux/src/layers/conv.jl | 2 +- GNNLux/test/layers/conv_tests.jl | 2 +- GNNlib/src/layers/conv.jl | 54 -------------------------------- 3 files changed, 2 insertions(+), 56 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 26af1cf7e..aabbbfb24 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -535,7 +535,7 @@ function SGConv(ch::Pair{Int, Int}, k = 1; add_self_loops::Bool = true, use_edge_weight::Bool = false) in_dims, out_dims = ch - return SGConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, k) + return SGConv(in_dims, out_dims, k, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias) end function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index a4e2fdf79..2f18103e1 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -79,7 +79,7 @@ end @testset "SGConv" begin - l = SGConv(in_dims => out_dims, relu) + l = SGConv(in_dims => out_dims, 2) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 0f345881e..91a80ff0f 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -722,58 +722,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix) T1_out = T2_out end return h .+ l.bias -end - -####################### SGConv ###################################### - -function sg_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, conv_weight::CW) where - {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} - if edge_weight !== nothing - @assert length(edge_weight) == g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" - end - - if conv_weight === nothing - weight = l.weight - else - weight = conv_weight - if size(weight) != size(l.weight) - throw(ArgumentError("The weight matrix has the wrong size. Expected $(size(l.weight)) but got $(size(weight))")) - end - end - - if l.add_self_loops - g = add_self_loops(g) - if edge_weight !== nothing - edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)] - @assert length(edge_weight) == g.num_edges - end - end - Dout, Din = size(l.weight) - if Dout < Din - x = l.weight * x - end - d = degree(g, T; dir=:in, edge_weight) - c = 1 ./ sqrt.(d) - for iter in 1:l.k - x = x .* c' - if edge_weight !== nothing - x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight) - elseif l.use_edge_weight - x = propagate(w_mul_xj, g, +, xj=x) - else - x = propagate(copy_xj, g, +, xj=x) - end - x = x .* c' - end - if Dout >= Din - x = l.weight * x - end - return (x .+ l.bias) -end - -# when we also have edge_weight we need to convert the graph to COO -function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, conv_weight::CW) where - {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} - g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO - return sg_conv(l, g, x, edge_weight, conv_weight) end \ No newline at end of file From 54957daa836ddb0dec5b1845d82dc4b8fc2c25ad Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 1 Aug 2024 20:02:40 +0530 Subject: [PATCH 4/5] fix --- GNNLux/src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index aabbbfb24..2a44651b2 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -568,7 +568,7 @@ function (l::SGConv)(g, x, edge_weight, ps, st; conv_weight=nothing, ) m = (; ps.weight, bias = _getbias(ps), - l.add_self_loops, l.use_edge_weight, l.σ) + l.add_self_loops, l.use_edge_weight, l.k) y = GNNlib.sg_conv(m, g, x, edge_weight, conv_weight) return y, st end \ No newline at end of file From a2ec7a38647ed87ba544cdcd715f058f8ab2c114 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 1 Aug 2024 20:31:25 +0530 Subject: [PATCH 5/5] fix --- GNNLux/src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 2a44651b2..672bbe20a 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -569,6 +569,6 @@ function (l::SGConv)(g, x, edge_weight, ps, st; m = (; ps.weight, bias = _getbias(ps), l.add_self_loops, l.use_edge_weight, l.k) - y = GNNlib.sg_conv(m, g, x, edge_weight, conv_weight) + y = GNNlib.sg_conv(m, g, x, edge_weight) return y, st end \ No newline at end of file