Skip to content

Commit

Permalink
new drop edge
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Aug 1, 2024
1 parent 4b4477e commit d2ab349
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 26 deletions.
88 changes: 62 additions & 26 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,40 +201,76 @@ function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:In
end

"""
remove_multi_edges(g::GNNGraph; aggr=+)
remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer})
remove_edges(g::GNNGraph, p::Float64=0.5)
Remove specified edges from a GNNGraph, either by specifying edge indices or by randomly removing edges with a given probability.
# Arguments
- `g`: The input graph from which edges will be removed.
- `edges_to_remove`: Vector of edge indices to be removed. This argument is only required for the first method.
- `p`: Probability of removing each edge. This argument is only required for the second method and defaults to 0.5.
Remove multiple edges (also called parallel edges or repeated edges) from graph `g`.
Possible edge features are aggregated according to `aggr`, that can take value
`+`,`min`, `max` or `mean`.
# Returns
A new GNNGraph with the specified edges removed.
See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref).
# Example
```julia
julia> using GraphNeuralNetworks
# Construct a GNNGraph
julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1])
GNNGraph:
num_nodes: 3
num_edges: 5
# Remove the second edge
julia> g_new = remove_edges(g, [2]);
julia> g_new
GNNGraph:
num_nodes: 3
num_edges: 4
# Remove edges with a probability of 0.5
julia> g_new = remove_edges(g, 0.5);
julia> g_new
GNNGraph:
num_nodes: 3
num_edges: 2
```
"""
function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +)
function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer})
s, t = edge_index(g)
w = get_edge_weight(g)
edata = g.edata
num_edges = g.num_edges
idxs, idxmax = edge_encoding(s, t, g.num_nodes)

perm = sortperm(idxs)
idxs = idxs[perm]
s, t = s[perm], t[perm]
edata = getobs(edata, perm)
w = isnothing(w) ? nothing : getobs(w, perm)
idxs = [-1; idxs]
mask = idxs[2:end] .> idxs[1:(end - 1)]
if !all(mask)
s, t = s[mask], t[mask]
idxs = similar(s, num_edges)
idxs .= 1:num_edges
idxs .= idxs .- cumsum(.!mask)
num_edges = length(s)
w = _scatter(aggr, w, idxs, num_edges)
edata = _scatter(aggr, edata, idxs, num_edges)
end

mask_to_keep = trues(length(s))

mask_to_keep[edges_to_remove] .= false

s = s[mask_to_keep]
t = t[mask_to_keep]
edata = getobs(edata, mask_to_keep)
w = isnothing(w) ? nothing : getobs(w, mask_to_keep)

return GNNGraph((s, t, w),
g.num_nodes, num_edges, g.num_graphs,
g.num_nodes, length(s), g.num_graphs,
g.graph_indicator,
g.ndata, edata, g.gdata)
end


function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5)
num_edges = g.num_edges
edges_to_remove = filter(_ -> rand() < p, 1:num_edges)
g = remove_edges(g, edges_to_remove)
s, t = edge_index(g)
w = get_edge_weight(g)
edata = g.edata
return GNNGraph((s, t, w),
g.num_nodes, length(s), g.num_graphs,
g.graph_indicator,
g.ndata, edata, g.gdata)
end
Expand Down
7 changes: 7 additions & 0 deletions GNNGraphs/test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ end
@test new_t == [4]
@test new_w == [0.3]
@test new_edata == ['c']

# drop with probability
gnew = remove_edges(g, Float32(1.0))
@test gnew.num_edges == 0

gnew = remove_edges(g, Float32(0.0))
@test gnew.num_edges == g.num_edges
end
end

Expand Down

0 comments on commit d2ab349

Please sign in to comment.