From 9d741a8d1e41b139d03c00a31f548fdeaa1d926d Mon Sep 17 00:00:00 2001 From: Ghaithq Date: Sun, 11 Feb 2024 15:58:29 +0200 Subject: [PATCH] extended GatedGraphConv and NNConv to use AbstactGNNGraph --- src/layers/conv.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 006b03091..e394b4cd0 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -589,7 +589,7 @@ end # remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521 @non_differentiable fill!(x...) -function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real} +function (l::GatedGraphConv)(g::AbstractGNNGraph, H::AbstractMatrix{S}) where {S <: Real} check_num_nodes(g, H) m, n = size(H) @assert (m<=l.out_ch) "number of input features must less or equals to output features." @@ -739,11 +739,11 @@ function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true, return NNConv(W, b, nn, σ, aggr) end -function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e) +function (l::NNConv)(g::AbstractGNNGraph, x, e) check_num_nodes(g, x) - - m = propagate(message, g, l.aggr, l, xj = x, e = e) - return l.σ.(l.weight * x .+ m .+ l.bias) + xj, xi = expand_srcdst(g, x) + m = propagate(message, g, l.aggr, l, xj = xj, e = e) + return l.σ.(l.weight * xi .+ m .+ l.bias) end function message(l::NNConv, xi, xj, e)