Skip to content

Commit

Permalink
fix dense test (#479)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Aug 3, 2024
1 parent 83b6b7e commit ef22e9a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
8 changes: 4 additions & 4 deletions GNNGraphs/src/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,10 @@ function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata
else
graph = g.graph
end
GNNGraph(graph,
g.num_nodes, g.num_edges, g.num_graphs,
g.graph_indicator,
ndata, edata, gdata)
return GNNGraph(graph,
g.num_nodes, g.num_edges, g.num_graphs,
g.graph_indicator,
ndata, edata, gdata)
end

"""
Expand Down
11 changes: 7 additions & 4 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ end
# when we also have edge_weight we need to convert the graph to COO
function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where
{EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F}
g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO
g = GNNGraph(g, graph_type = :coo)
return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight)
end

Expand Down Expand Up @@ -449,9 +449,10 @@ function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T},
return (x .+ l.bias)
end

# when we also have edge_weight we need to convert the graph to COO
function sgc_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
edge_weight::AbstractVector)
g = GNNGraph(edge_index(g)...; g.num_nodes)
g = GNNGraph(g; graph_type=:coo)
return sgc_conv(l, g, x, edge_weight)
end

Expand Down Expand Up @@ -542,9 +543,10 @@ function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T},
return (x .+ l.bias)
end

# when we also have edge_weight we need to convert the graph to COO
function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
edge_weight::AbstractVector)
g = GNNGraph(edge_index(g)...; g.num_nodes)
g = GNNGraph(g; graph_type=:coo)
return sg_conv(l, g, x, edge_weight)
end

Expand Down Expand Up @@ -684,9 +686,10 @@ function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T},
return (sum_total .+ l.bias)
end

# when we also have edge_weight we need to convert the graph to COO
function tag_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
edge_weight::AbstractVector)
g = GNNGraph(edge_index(g)...; g.num_nodes)
g = GNNGraph(g; graph_type = :coo)
return l(g, x, edge_weight)
end

Expand Down

0 comments on commit ef22e9a

Please sign in to comment.