diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 54fcfc14f..1e7603897 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -149,57 +149,6 @@ function remove_self_loops(g::GNNGraph{<:ADJMAT_T}) g.ndata, g.edata, g.gdata) end -""" - remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) - -Remove specified edges from a GNNGraph. - -# Arguments -- `g`: The input graph from which edges will be removed. -- `edges_to_remove`: Vector of edge indices to be removed. - -# Returns -A new GNNGraph with the specified edges removed. - -# Example -```julia -julia> using GraphNeuralNetworks - -# Construct a GNNGraph -julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1]) -GNNGraph: - num_nodes: 3 - num_edges: 5 - -# Remove the second edge -julia> g_new = remove_edges(g, [2]); - -julia> g_new -GNNGraph: - num_nodes: 3 - num_edges: 4 -``` -""" -function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer}) - s, t = edge_index(g) - w = get_edge_weight(g) - edata = g.edata - - mask_to_keep = trues(length(s)) - - mask_to_keep[edges_to_remove] .= false - - s = s[mask_to_keep] - t = t[mask_to_keep] - edata = getobs(edata, mask_to_keep) - w = isnothing(w) ? nothing : getobs(w, mask_to_keep) - - return GNNGraph((s, t, w), - g.num_nodes, length(s), g.num_graphs, - g.graph_indicator, - g.ndata, edata, g.gdata) -end - """ remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) remove_edges(g::GNNGraph, p::Float64=0.5) @@ -275,6 +224,45 @@ function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5) g.ndata, edata, g.gdata) end +""" + remove_multi_edges(g::GNNGraph; aggr=+) + +Remove multiple edges (also called parallel edges or repeated edges) from graph `g`. +Possible edge features are aggregated according to `aggr`, that can take value +`+`,`min`, `max` or `mean`. + +See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref). +""" +function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +) + s, t = edge_index(g) + w = get_edge_weight(g) + edata = g.edata + num_edges = g.num_edges + idxs, idxmax = edge_encoding(s, t, g.num_nodes) + + perm = sortperm(idxs) + idxs = idxs[perm] + s, t = s[perm], t[perm] + edata = getobs(edata, perm) + w = isnothing(w) ? nothing : getobs(w, perm) + idxs = [-1; idxs] + mask = idxs[2:end] .> idxs[1:(end - 1)] + if !all(mask) + s, t = s[mask], t[mask] + idxs = similar(s, num_edges) + idxs .= 1:num_edges + idxs .= idxs .- cumsum(.!mask) + num_edges = length(s) + w = _scatter(aggr, w, idxs, num_edges) + edata = _scatter(aggr, edata, idxs, num_edges) + end + + return GNNGraph((s, t, w), + g.num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) +end + """ remove_nodes(g::GNNGraph, nodes_to_remove::AbstractVector)