diff --git a/GNNGraphs/src/gnngraph.jl b/GNNGraphs/src/gnngraph.jl index 64fd32aad..a9af576e2 100644 --- a/GNNGraphs/src/gnngraph.jl +++ b/GNNGraphs/src/gnngraph.jl @@ -209,10 +209,10 @@ function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata else graph = g.graph end - GNNGraph(graph, - g.num_nodes, g.num_edges, g.num_graphs, - g.graph_indicator, - ndata, edata, gdata) + return GNNGraph(graph, + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + ndata, edata, gdata) end """ diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 50b5b34aa..3a5c543a1 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -74,7 +74,7 @@ end # when we also have edge_weight we need to convert the graph to COO function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} - g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO + g = GNNGraph(g, graph_type = :coo) return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight) end @@ -449,9 +449,10 @@ function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T}, return (x .+ l.bias) end +# when we also have edge_weight we need to convert the graph to COO function sgc_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector) - g = GNNGraph(edge_index(g)...; g.num_nodes) + g = GNNGraph(g; graph_type=:coo) return sgc_conv(l, g, x, edge_weight) end @@ -542,9 +543,10 @@ function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T}, return (x .+ l.bias) end +# when we also have edge_weight we need to convert the graph to COO function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector) - g = GNNGraph(edge_index(g)...; g.num_nodes) + g = GNNGraph(g; graph_type=:coo) return sg_conv(l, g, x, edge_weight) end @@ -684,9 +686,10 @@ function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T}, return (sum_total .+ l.bias) end +# when we also have edge_weight we need to convert the graph to COO function tag_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector) - g = GNNGraph(edge_index(g)...; g.num_nodes) + g = GNNGraph(g; graph_type = :coo) return l(g, x, edge_weight) end