From 4fed7e8b01ee317c7c3b03383a8b10113fb22aaa Mon Sep 17 00:00:00 2001 From: rbSparky Date: Fri, 15 Apr 2022 11:46:20 +0530 Subject: [PATCH 1/3] Added main code --- src/GraphNeuralNetworks.jl | 1 + src/layers/conv.jl | 66 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 11c3e4039..a680f77aa 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -61,6 +61,7 @@ export ResGatedGraphConv, SAGEConv, GMMConv, + EdgeWeightNorm, # layers/pool GlobalPool, diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 7bfbe9592..62a2123bc 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1181,3 +1181,69 @@ function Base.show(io::IO, l::GMMConv) l.residual==true || print(io, ", residual=", l.residual) print(io, ")") end + +@doc raw""" + EdgeWeightNorm(norm_both = true, eps = 0) + +Normalizes positive scalar edge weights on a graph following the form in GCN. + +norm_both = `true` yields the following normalization term: +```math +c_{ji} = (\sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}}\sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}}) +``` +norm_both = `false` yields the following normalization term: +```math +c_{ji} = (\sum_{k\in\mathcal{N}(i)}e_{ki}) +``` +where ``e_ji`` is the scalar weight on the edge from node j to node i. + +Return value is the normalized weight ``e_{ji} / c_{ji}`` for all edges in vector form. + +# Arguments + +- `norm_both`: The normalizer as specified above. Default is `true`. +- `eps`: Offset value in the denominator. Default is `0`. + +# Examples + +```julia +# create data +g = GNNGraph([1,2,3,4,3,6], [2,3,4,5,1,4]) +g = add_self_loops(g) + +# edge weights +edge_weights = [0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1] + +l = EdgeWeightNorm() +l(g, edge_weights) +``` +""" +struct EdgeWeightNorm <: GNNLayer + norm_both::Bool + eps::Float64 +end + +@functor EdgeWeightNorm + +function EdgeWeightNorm(norm_both::Bool = true, + eps::Float64 = 0) + EdgeWeightNorm(norm_both, eps) +end + +function (l::EdgeWeightNorm)(g::GNNGraph, edge_weight::AbstractVector) + norm_val = Vector{Float64}() + edge_in, edge_out = edge_index(g) + + dg_in = degree(g; dir = :in, edge_weight) + dg_out = degree(g; dir = :out, edge_weight) + + for iter in 1:length(edge_weight) + if l.norm_both + push!(norm_val, edge_weight[iter] / (sqrt(dg_out[in[iter]] * dg_in[out[iter]]) + l.eps)) + else + push!(norm_val, edge_weight[iter] / (dg_in[out[iter]] + l.eps)) + end + end + + return norm_val +end From 7ca7589aaec43bd2189fd068599ef83f66b9a66b Mon Sep 17 00:00:00 2001 From: Rishabh B <59335537+rbSparky@users.noreply.github.com> Date: Fri, 15 Apr 2022 11:57:04 +0530 Subject: [PATCH 2/3] Small doc fix --- src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 62a2123bc..25da7454f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1195,7 +1195,7 @@ norm_both = `false` yields the following normalization term: ```math c_{ji} = (\sum_{k\in\mathcal{N}(i)}e_{ki}) ``` -where ``e_ji`` is the scalar weight on the edge from node j to node i. +where ``e_{ji}`` is the scalar weight on the edge from node j to node i. Return value is the normalized weight ``e_{ji} / c_{ji}`` for all edges in vector form. From 71f5455bed75ea058f782ec0e94268cd70cf650c Mon Sep 17 00:00:00 2001 From: rbSparky Date: Tue, 3 May 2022 10:15:59 +0530 Subject: [PATCH 3/3] Changed output type to be same as edge_weight --- src/layers/conv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 62a2123bc..80b4058e8 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1230,8 +1230,8 @@ function EdgeWeightNorm(norm_both::Bool = true, EdgeWeightNorm(norm_both, eps) end -function (l::EdgeWeightNorm)(g::GNNGraph, edge_weight::AbstractVector) - norm_val = Vector{Float64}() +function (l::EdgeWeightNorm)(g::GNNGraph, edge_weight::T) where T <: AbstractVector + norm_val = T() edge_in, edge_out = edge_index(g) dg_in = degree(g; dir = :in, edge_weight)